diff --git a/tests/model_executor/model_loader/test_remote_instance_loader.py b/tests/model_executor/model_loader/test_remote_instance_loader.py new file mode 100644 index 000000000000..33d8f95b68f2 --- /dev/null +++ b/tests/model_executor/model_loader/test_remote_instance_loader.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Tests for the remote instance model loader. + +To run these tests: +1. Install test dependencies: + uv pip install -r requirements/common.txt -r requirements/dev.txt + --torch-backend=auto + uv pip install pytest pytest-asyncio + +2. Run the tests: + pytest -s -v tests/model_executor/model_loader/test_remote_instance_loader.py + +Note: This test is marked as skip because it requires: +- Multiple GPUs (at least 8 GPUs for 2x2 TP/PP configuration for both seed + and client instances) +- Coordinated seed and client servers +- Proper setup of environment variables +- Network communication between servers +""" + +from http import HTTPStatus + +import pytest +import requests +from huggingface_hub import snapshot_download + +from tests.utils import RemoteOpenAIServer + +# Test prompts +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +@pytest.fixture(scope="module") +def llama_3p2_1b_files(): + """Download the Llama-3.2-1B-Instruct model files for testing.""" + input_dir = snapshot_download( + "meta-llama/Llama-3.2-1B-Instruct", ignore_patterns=["*.bin*", "original/*"] + ) + yield input_dir + + +def test_remote_instance_loader_end_to_end(llama_3p2_1b_files, num_gpus_available): + """ + End-to-end test for the remote instance loader. + + This test simulates the manual testing procedure: + 1. Start a seed server (source of weights) + 2. Start a client server (loads weights from seed server) + 3. 
Compare outputs from both servers + + Note: This test is marked as skip because it requires: + - Multiple GPUs (at least 8 GPUs for 2x2 TP/PP configuration for both + seed and client instances) + - Coordinated seed and client servers + - Proper setup of environment variables + - Network communication between servers + """ + # Need at least 8 GPUs (4 for seed instance + 4 for client instance) + if num_gpus_available < 8: + pytest.skip( + "Not enough GPUs for 2x2 TP/PP configuration for both seed and " + "client instances (requires 8 GPUs)" + ) + + input_dir = llama_3p2_1b_files + seed_port = 12346 + client_port = 12347 + gpu_memory_utilization = 0.8 + + # Server arguments for both seed and client instances + common_args = [ + "--tensor-parallel-size", + "2", + "--pipeline-parallel-size", + "2", + "--gpu-memory-utilization", + str(gpu_memory_utilization), + "--max-model-len", + "1024", + "--enforce-eager", + ] + + # Run seed server (source of weights) + seed_args = [ + "--host", + "127.0.0.1", + "--port", + str(seed_port), + *common_args, + ] + + with RemoteOpenAIServer(input_dir, seed_args, auto_port=False) as seed_server: + # Check if seed server is running + response = requests.get(seed_server.url_for("health")) + assert response.status_code == HTTPStatus.OK + + # Run client server (loads weights from seed server) + # Set environment variables for remote instance loading + # Use different GPUs for client instance to avoid conflict with seed instance + client_env_dict = { + "REMOTE_INSTANCE_IP": "127.0.0.1", + "REMOTE_INSTANCE_SERVER_PORT": str(seed_port), + "REMOTE_INSTANCE_PORTS": "[50000,50001,50002,50003]", + "CUDA_VISIBLE_DEVICES": "4,5,6,7", # Use different GPUs for client + } + + client_args = [ + "--host", + "127.0.0.1", + "--port", + str(client_port), + "--load-format", + "remote_instance", + *common_args, + ] + + with RemoteOpenAIServer( + input_dir, client_args, env_dict=client_env_dict, auto_port=False + ) as client_server: + # Check if client server is 
running + response = requests.get(client_server.url_for("health")) + assert response.status_code == HTTPStatus.OK + + # Get clients for both servers + seed_client = seed_server.get_client() + client_client = client_server.get_client() + + # Get the model name from the seed server + seed_models = seed_client.models.list() + seed_model_name = seed_models.data[0].id + + # Get the model name from the client server + client_models = client_client.models.list() + client_model_name = client_models.data[0].id + + # Generate outputs from both servers and compare + for prompt in prompts: + # Generate from seed server + seed_response = seed_client.completions.create( + model=seed_model_name, + prompt=prompt, + max_tokens=256, + temperature=0.0, + ) + seed_text = seed_response.choices[0].text + + # Generate from client server + client_response = client_client.completions.create( + model=client_model_name, + prompt=prompt, + max_tokens=256, + temperature=0.0, + ) + client_text = client_response.choices[0].text + + # Compare outputs + assert seed_text == client_text, ( + f"Outputs from seed and client servers should be identical.\n" + f"Prompt: {prompt}\n" + f"Seed output: {seed_text}\n" + f"Client output: {client_text}" + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/weight_transfer_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/weight_transfer_connector.py new file mode 100644 index 000000000000..65650e8707d8 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/weight_transfer_connector.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/sgl-project/sglang/pull/8215 + +from datetime import timedelta +from typing import Any + +import torch +from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + Store, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, +) + +from 
vllm.attention.backends.abstract import AttentionMetadata +from vllm.forward_context import ForwardContext +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class WeightTransferConnector: + """weight transfer connectors for RemoteInstanceLoader.""" + + def __init__(self, url: str): + self.url = url + self.closed = False + self._model_update_group = None + + def build_group( + self, + gpu_id: int = -1, + client_rank: int = -1, + client_id: str = "", + group_rank: int = 1, + world_size: int = 2, + ): + assert gpu_id != -1 and client_rank != -1, ( + "gpu_id and tp_rank must be specified for RemoteInstanceConnector. " + ) + + self.device_id = torch.device("cuda", gpu_id) + master_address, master_port = self.url.split(":") + group_name = f"send_weights_{client_id}_{client_rank}" + backend = "nccl" + + logger.info( + "init custom process group: master_address=%s, master_port=%s, " + "rank_offset=%s, world_size=%s, group_name=%s, backend=%s, gpu_id=%s", + master_address, + master_port, + group_rank, + world_size, + group_name, + backend, + gpu_id, + ) + + try: + self._model_update_group = init_custom_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + timeout=timedelta(seconds=60), + world_size=world_size, + rank=group_rank, + group_name=group_name, + device_id=self.device_id, + ) + + return True, "Succeeded to initialize custom process group." + except Exception as e: + message = f"Failed to initialize custom process group: {e}." 
+ logger.error(message) + return False, message + + def close(self): + if self.closed: + return + self.closed = True + if self._model_update_group is not None: + torch.distributed.distributed_c10d.destroy_process_group( + self._model_update_group + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + self.close() + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: + return + + def wait_for_layer_load(self, layer_name: str) -> None: + return + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: + return + + def wait_for_save(self): + return + + +# Copy from pytorch and OpenRLHF to allow creating multiple main groups. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py +# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py +def init_custom_process_group( + backend: str | None = None, + init_method: str | None = None, + timeout: timedelta | None = None, + world_size: int = -1, + rank: int = -1, + store: Store | None = None, + group_name: str = "", + pg_options: Any | None = None, + device_id: torch.device | int | None = None, +): + assert (store is None) or (init_method is None), ( + "Cannot specify both init_method and store." 
+ ) + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + backend = Backend(backend) if backend else Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore(group_name, store) + + # Get the rank of the current process in the default process group + my_rank = torch.distributed.get_rank() + global_ranks_in_group = [my_rank] # Must include itself at least + logger.debug("global_ranks_in_group: %s", global_ranks_in_group) + + # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 + # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 + # We need to determine the appropriate parameter name based on PyTorch version + pg_options_param_name = ( + "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" + ) + pg, _ = _new_process_group_helper( + world_size, + rank, + global_ranks_in_group, + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + device_id=device_id, + ) + + _world.pg_group_ranks[pg] = { + global_rank: group_rank + for group_rank, global_rank in enumerate(global_ranks_in_group) + } + logger.debug("_world: %s", _world.pg_group_ranks[pg]) + return pg diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 555c95effd1d..ea70de615892 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -62,6 +62,7 @@ 
EmbeddingResponse, ErrorInfo, ErrorResponse, + InitWeightsSendGroupForRemoteInstanceReqInput, IOProcessorResponse, LoadLoRAAdapterRequest, PoolingRequest, @@ -72,6 +73,7 @@ ResponsesResponse, ScoreRequest, ScoreResponse, + SendWeightsToRemoteInstanceReqInput, StreamingResponsesResponse, TokenizeRequest, TokenizeResponse, @@ -1260,6 +1262,44 @@ def load_log_config(log_config_file: str | None) -> dict | None: return None +# Adapted from https://github.com/sgl-project/sglang/pull/8215 +@router.post("/init_weights_send_group_for_remote_instance") +async def init_weights_send_group_for_remote_instance( + obj: InitWeightsSendGroupForRemoteInstanceReqInput, request: Request +): + results = await request.app.state.engine_client.collective_rpc( + "init_weights_send_group_for_remote_instance", + args=(obj.model_dump(),), + timeout=30.0, + ) + all_success = all(r["success"] for r in results) + return JSONResponse( + content={ + "success": all_success, + "message": "Initialized" if all_success else "Failed", + }, + status_code=200 if all_success else 400, + ) + + +# Adapted from https://github.com/sgl-project/sglang/pull/8215 +@router.post("/send_weights_to_remote_instance") +async def send_weights_to_remote_instance( + obj: SendWeightsToRemoteInstanceReqInput, request: Request +): + results = await request.app.state.engine_client.collective_rpc( + "send_weights_to_remote_instance", args=(obj.model_dump(),), timeout=60.0 + ) + all_success = all(r["success"] for r in results) + return JSONResponse( + content={ + "success": all_success, + "message": "Initialized" if all_success else "Failed", + }, + status_code=200 if all_success else 400, + ) + + class AuthenticationMiddleware: """ Pure ASGI middleware that authenticates each request by checking diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0d27e6707c23..7c99790d56a8 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3016,3 +3016,25 @@ 
 class TranslationResponseVerbose(OpenAIBaseModel):
     words: list[TranslationWord] | None = None
     """Extracted words and their corresponding timestamps."""
+
+
+# Adapted from https://github.com/sgl-project/sglang/pull/8215
+class InitWeightsSendGroupForRemoteInstanceReqInput(OpenAIBaseModel):
+    """Request to initialize weights send group for remote instance."""
+
+    master_address: str = Field(description="The master address")
+    ports: str = Field(description="The ports for each rank's communication group")
+    group_rank: int = Field(description="The rank in the communication group")
+    world_size: int = Field(description="The world size")
+    group_name: str = Field(default="weight_send_group", description="The group name")
+    backend: str = Field(default="nccl", description="The backend")
+
+
+# Adapted from https://github.com/sgl-project/sglang/pull/8215
+class SendWeightsToRemoteInstanceReqInput(OpenAIBaseModel):
+    """Request to send weights to remote instance."""
+
+    master_address: str = Field(description="The master address")
+    ports: str = Field(description="The ports for each rank's communication group")
+    group_name: str = Field(default="weight_send_group", description="The group name")
+    state_dict: int = Field(description="The state_dict() length")
diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py
index 301f2d00bf40..8982a00c918d 100644
--- a/vllm/model_executor/model_loader/__init__.py
+++ b/vllm/model_executor/model_loader/__init__.py
@@ -13,6 +13,9 @@
 from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
 from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
 from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
+from vllm.model_executor.model_loader.remote_instance_loader import (
+    RemoteInstanceModelLoader,
+)
 from vllm.model_executor.model_loader.runai_streamer_loader import (
     RunaiModelStreamerLoader,
 )
@@ -42,6 +45,7 @@
     "safetensors",
"sharded_state", "tensorizer", + "remote_instance", ] _LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = { "auto": DefaultModelLoader, @@ -57,6 +61,7 @@ "safetensors": DefaultModelLoader, "sharded_state": ShardedStateLoader, "tensorizer": TensorizerLoader, + "remote_instance": RemoteInstanceModelLoader, } diff --git a/vllm/model_executor/model_loader/remote_instance_loader.py b/vllm/model_executor/model_loader/remote_instance_loader.py new file mode 100644 index 000000000000..9b1a334858d7 --- /dev/null +++ b/vllm/model_executor/model_loader/remote_instance_loader.py @@ -0,0 +1,239 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from https://github.com/sgl-project/sglang/pull/8215 + +# Changes: +# - Add support for Pipeline parallel + Tensor parallel hybrid parallelism +# - Add basic model and tensor count validation + +import os +import threading +import time + +import torch +from torch import nn + +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig +from vllm.distributed.kv_transfer.kv_connector.v1.weight_transfer_connector import ( + WeightTransferConnector, +) +from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_world_group +from vllm.logger import init_logger +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, + process_weights_after_loading_mla, + process_weights_after_loading_quant, +) +from vllm.utils.torch_utils import set_default_torch_dtype + +logger = init_logger(__name__) + + +class RemoteInstanceModelLoader(BaseModelLoader): + """ + Get model weights from GPUs of other vLLM instances + Only support loading weights from instance with same parallelism strategy + """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + extra_config = ( + {} + if 
load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy() + ) + + if extra_config: + raise ValueError( + f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}" + ) + + def download_model(self, model_config: ModelConfig) -> None: + raise NotImplementedError + + def load_model( + self, vllm_config: VllmConfig, model_config: ModelConfig + ) -> nn.Module: + """Load a model with the given configurations.""" + self.client_id = vllm_config.instance_id + self.trigger(model_config) + device_config = vllm_config.device_config + load_config = vllm_config.load_config + load_device = ( + device_config.device if load_config.device is None else load_config.device + ) + + target_device = torch.device(load_device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model( + vllm_config=vllm_config, model_config=model_config + ) + process_weights_after_loading_quant(model, model_config, target_device) + begin = time.perf_counter() + self.load_weights(model, model_config) + end = time.perf_counter() + + logger.info("Loading weights on %s using %s s", load_device, end - begin) + process_weights_after_loading_mla(model, model_config) + + return model.eval() + + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: + """Load a model with the given configurations.""" + global_rank = _get_rank() + url = f"{_get_seed_instance_ip()}:{_get_instance_ports()[global_rank]}" + logger.info("url: %s", url) + + with WeightTransferConnector(url) as client: + self.load_model_from_remote_instance(model, client, model_config) + + def load_model_from_remote_instance( + self, + model: nn.Module, + client: WeightTransferConnector, + model_config: ModelConfig, + ) -> None: + start_build_group_tic = time.time() + # To support tp, pp + global_rank = _get_rank() + success, message = client.build_group( + 
gpu_id=torch.cuda.current_device(),
+            client_rank=global_rank,
+            client_id=self.client_id,
+        )
+        if not success:
+            raise RuntimeError(f"Failed to build group for remote instance: {message}")
+        # Wait for rank0 to complete trigger()
+        get_world_group().barrier()
+
+        end_build_group_tic = time.time()
+        logger.info(
+            "finish building group for remote instance, time used: %.4fs",
+            end_build_group_tic - start_build_group_tic,
+        )
+
+        from vllm.model_executor.model_loader.remote_instance_loader_utils import (
+            trigger_transferring_weights_request,
+        )
+
+        if global_rank == 0:
+            t = threading.Thread(
+                target=trigger_transferring_weights_request,
+                args=(
+                    _get_seed_instance_ip(),
+                    _get_instance_service_port(),
+                    _get_instance_ports(),
+                    self.client_id,
+                    sum(1 for v in model.state_dict().values() if v.numel() > 0),
+                ),
+            )
+            t.start()
+
+        try:
+            torch.cuda.empty_cache()
+            logger.info("Recv weight in %s", client._model_update_group)
+            start_get_weights_tic = time.time()
+            with set_default_torch_dtype(model_config.dtype):
+                state_dict = model.state_dict()
+                for key, tensor in state_dict.items():
+                    if tensor.numel():
+                        torch.distributed.broadcast(
+                            tensor,
+                            group_src=0,
+                            group=client._model_update_group,
+                        )
+
+            end_get_weights_tic = time.time()
+            logger.info(
+                "finish getting all weights from remote instance, time used: %.4fs",
+                end_get_weights_tic - start_get_weights_tic,
+            )
+            torch.cuda.empty_cache()
+        except Exception as e:
+            # A failed broadcast must not be swallowed: the model would be
+            # returned with uninitialized weights.
+            message = f"Failed to receive weights from remote instance: {e}."
+            logger.error(message)
+            raise RuntimeError(message) from e
+
+    def trigger(self, model_config: ModelConfig):
+        global_rank = _get_rank()
+        if global_rank != 0:
+            return
+
+        from vllm.model_executor.model_loader.remote_instance_loader_utils import (
+            get_remote_instance_model,
+            trigger_init_weights_send_group_for_remote_instance_request,
+        )
+
+        try:
+            remote_model_id = get_remote_instance_model(
+                _get_seed_instance_ip(), _get_instance_service_port()
+            )
+        except Exception as e:
+            raise ValueError(f"Failed to get remote model info: {e}") from e
+
+        if _normalize_model_id(remote_model_id) != _normalize_model_id(
+            model_config.model
+        ):
+            raise ValueError(
+                f"Model mismatch: remote model '{remote_model_id}' "
+                f"does not match local model '{model_config.model}'"
+            )
+
+        t = threading.Thread(
+            target=trigger_init_weights_send_group_for_remote_instance_request,
+            args=(
+                _get_seed_instance_ip(),
+                _get_instance_service_port(),
+                _get_instance_ports(),
+                self.client_id,
+            ),
+        )
+        t.start()
+
+
+def _get_seed_instance_ip() -> str:
+    ip = os.environ.get("REMOTE_INSTANCE_IP")
+    if ip is None:
+        raise ValueError(
+            "REMOTE_INSTANCE_IP environment variable is not set. "
+            "Please set REMOTE_INSTANCE_IP to the IP address of the remote instance."
+ ) + return ip + + +def _get_instance_ports() -> list[int]: + import json + + ports_str = os.environ.get( + "REMOTE_INSTANCE_PORTS", "[50000,50001,50002,50003,50004,50005,50006,50007]" + ) + return json.loads(ports_str) + + +def _get_instance_service_port() -> int: + return int(os.environ.get("REMOTE_INSTANCE_SERVER_PORT", "30000")) + + +def _get_rank() -> int: + tp_rank = get_tp_group().rank_in_group + tp_size = get_tp_group().world_size + + pp_rank = get_pp_group().rank_in_group + global_rank = pp_rank * tp_size + tp_rank + + return global_rank + + +def _normalize_model_id(model_id: str) -> str: + """Normalize model ID, remove path prefix, keep only model name""" + # If it's a path, extract the last directory name + if "/" in model_id: + return model_id.rstrip("/").split("/")[-1] + return model_id diff --git a/vllm/model_executor/model_loader/remote_instance_loader_utils.py b/vllm/model_executor/model_loader/remote_instance_loader_utils.py new file mode 100644 index 000000000000..009f0274c71f --- /dev/null +++ b/vllm/model_executor/model_loader/remote_instance_loader_utils.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from https://github.com/sgl-project/sglang/pull/8215 + +# Changes: +# - Add support for Pipeline parallel + Tensor parallel hybrid parallelism +# - Add basic model and tensor count validation + +import logging +import threading +import time +from typing import Any + +import requests +import torch +import torch.distributed +import torch.nn as nn + +from vllm.distributed.kv_transfer.kv_connector.v1.weight_transfer_connector import ( + init_custom_process_group, +) + +logger = logging.getLogger(__name__) + +# Dictionary to store process groups +# Format: {group_name: process_group} +_weights_send_group: dict[str, Any] = {} + + +def cleanup_thread(group_name: str, delay: float = 60.0): + # Create a thread to clean up the process group + # if it's not used 
within delay seconds + cleanup_thread = threading.Thread( + target=_cleanup_stale_group, args=(group_name, delay) + ) + cleanup_thread.daemon = ( + True # Set as daemon thread so it doesn't prevent program exit + ) + cleanup_thread.start() + + +def _cleanup_stale_group(group_name: str, delay: float = 60.0): + """Clean up a stale process group that has timed out.""" + try: + # Sleep for the specified delay + time.sleep(delay) + + # Check if the group still exists and clean it up + process_group = _weights_send_group.get(group_name) + if process_group is not None: + torch.distributed.distributed_c10d.destroy_process_group(process_group) + del _weights_send_group[group_name] + logger.info("Cleaned up stale process group: %s", group_name) + + except Exception as e: + logger.warning("Failed to clean up stale process group %s: %s", group_name, e) + + +def trigger_init_weights_send_group_for_remote_instance_request( + remote_seed_instance_ip: str, + remote_seed_instance_service_port: int, + send_weights_group_ports: list[int], + local_client_id: str, +): + seed_instance_service_url = ( + f"http://{remote_seed_instance_ip}:{remote_seed_instance_service_port}" + ) + + try: + requests.post( + f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance", + json={ + "master_address": remote_seed_instance_ip, + "ports": (",".join(str(p) for p in send_weights_group_ports)), + "group_rank": 0, + "world_size": 2, + "group_name": f"send_weights_{local_client_id}", + "backend": "nccl", + }, + ) + except Exception as e: + logger.error( + "Failed to trigger init_weights_send_group_for_remote_instance_request to " + "seed instance %s: %s.", + seed_instance_service_url, + e, + ) + raise + + +def trigger_transferring_weights_request( + remote_seed_instance_ip: str, + remote_seed_instance_service_port: int, + send_weights_group_ports: list[int], + local_client_id: str, + tensors_nums: int, +): + seed_instance_service_url = ( + 
f"http://{remote_seed_instance_ip}:{remote_seed_instance_service_port}" + ) + try: + requests.post( + f"{seed_instance_service_url}/send_weights_to_remote_instance", + json={ + "master_address": remote_seed_instance_ip, + "ports": (",".join(str(p) for p in send_weights_group_ports)), + "group_name": f"send_weights_{local_client_id}", + "state_dict": f"{tensors_nums}", + }, + ) + except Exception as e: + logger.error("Failed to trigger send weights to remote instance request: %s", e) + raise + + +def get_remote_instance_model( + remote_seed_instance_ip: str, + remote_seed_instance_service_port: int, +) -> str: + # Get model information from the ready instance + seed_instance_service_url = ( + f"http://{remote_seed_instance_ip}:{remote_seed_instance_service_port}" + ) + response = requests.get(f"{seed_instance_service_url}/v1/models") + models_info = response.json() + + # Verify if the model matches + ready_model_id = models_info["data"][0]["id"] + return ready_model_id + + +def init_weights_send_group_for_remote_instance( + master_address: str, + ports: str, + group_rank: int, + world_size: int, + group_name: str, + backend: str = "nccl", +): + import time + + begin = time.perf_counter() + + assert torch.distributed.is_initialized(), ( + "Default torch process group must be initialized" + ) + assert group_name != "", "Group name cannot be empty" + from vllm.distributed.parallel_state import get_tp_group + + tp_rank = get_tp_group().rank_in_group + tp_size = get_tp_group().world_size + + from vllm.distributed.parallel_state import get_pp_group + + pp_rank = get_pp_group().rank_in_group + pp_size = get_pp_group().world_size + global_rank = pp_rank * tp_size + tp_rank + + ports_list = ports.split(",") + gpu_id = torch.cuda.current_device() + + assert len(ports_list) == tp_size * pp_size, ( + f"Expected {tp_size * pp_size} ports, but got {len(ports_list)} ports." 
+ ) + group_port = ports_list[global_rank] + group_name = f"{group_name}_{global_rank}" + + logger.info( + "init custom process group: pp_rank=%s, tp_rank=%s, " + "gpu_id=%s, master_address=%s, master_port=%s, " + "group_rank=%s, world_size=%s, group_name=%s, " + "backend=%s", + pp_rank, + tp_rank, + gpu_id, + master_address, + group_port, + group_rank, + world_size, + group_name, + backend, + ) + + try: + from datetime import timedelta + + _weights_send_group[group_name] = init_custom_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{group_port}", + timeout=timedelta(seconds=60), + world_size=world_size, + rank=group_rank, + group_name=group_name, + device_id=torch.device("cuda", gpu_id), + ) + cleanup_thread(group_name, 60) + + message = ( + f"Succeeded to init group through {master_address}:{group_port} group." + ) + end = time.perf_counter() + logger.info("init_weights_send_group_for_remote_instance using %s", end - begin) + return {"success": True, "message": message} + except Exception as e: + message = f"Failed to init group: {e}." + logger.error(message) + return {"success": False, "message": message} + + +def send_weights_to_remote_instance( + master_address: str, + ports: str, + group_name: str, + remote_tensor_nums: int, + model: nn.Module, +): + assert torch.distributed.is_initialized(), ( + "Default torch process group must be initialized" + ) + assert group_name != "", "Group name cannot be empty" + + from vllm.distributed.parallel_state import get_tp_group + + tp_rank = get_tp_group().rank_in_group + tp_size = get_tp_group().world_size + + from vllm.distributed.parallel_state import get_pp_group + + pp_rank = get_pp_group().rank_in_group + pp_size = get_pp_group().world_size + global_rank = pp_rank * tp_size + tp_rank + + ports_list = ports.split(",") + assert len(ports_list) == tp_size * pp_size, ( + f"Expected {tp_size * pp_size} ports, but got {len(ports_list)} ports." 
+ ) + group_port = ports_list[global_rank] + group_name = f"{group_name}_{global_rank}" + + send_group = None + success = False + message = "" + + try: + # Count non-empty tensors in the model's state_dict + non_empty_count = sum(1 for v in model.state_dict().values() if v.numel() > 0) + + # Safety check: Only worker0 validates tensor count + validation_passed = True + if global_rank == 0 and remote_tensor_nums != non_empty_count: + validation_passed = False + logger.error( + "[Worker0] Tensor count mismatch between local and remote instances. " + "Local non-empty tensor count: %s, " + "Remote tensor count: %s. " + "Aborting weight broadcast for all workers.", + non_empty_count, + remote_tensor_nums, + ) + + # Broadcast validation result from worker0 to all workers + from vllm.distributed.parallel_state import get_world_group + + world_group = get_world_group() + validation_result = [validation_passed] + world_group.broadcast_object_list(validation_result, src=0) + validation_passed = validation_result[0] + + # If validation failed, all workers abort + if not validation_passed: + message = ( + f"[Worker{global_rank}] Aborting weight broadcast due to worker0 " + f"validation failure." + ) + logger.error(message) + return {"success": False, "message": message} + + # Get the process group + send_group = _weights_send_group.get(group_name) + if send_group is None: + message = ( + f"Group {group_name} not in _weights_send_group list. " + f"Please call `init_weights_send_group_for_remote_instance` first." + ) + logger.error(message) + return {"success": False, "message": message} + + logger.info("Send weight in %s", send_group) + torch.cuda.empty_cache() + state_dict = model.state_dict() + for key, tensor in state_dict.items(): + if tensor.numel(): + torch.distributed.broadcast( + tensor, + group_src=0, + group=send_group, + ) + torch.cuda.empty_cache() + success = True + message = ( + f"Succeeded to send weights through {master_address}:{group_port} " + f"{group_name}." 
+ ) + except Exception as e: + message = f"Failed to send weights: {e}." + logger.error(message) + logger.error("Model state_dict keys: %s", list(model.state_dict().keys())) + logger.error("Number of state_dict items: %s", len(model.state_dict())) + finally: + # destroy the process group after sending weights + try: + if group_name in _weights_send_group: + del _weights_send_group[group_name] + if send_group is not None: + torch.distributed.distributed_c10d.destroy_process_group(send_group) + except Exception as cleanup_error: + logger.warning("Failed to clean up process group: %s", cleanup_error) + + return {"success": success, "message": message} diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 88dfbc33e10b..7b9f2cea58ec 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -117,6 +117,42 @@ def process_weights_after_loading( module.process_weights_after_loading(model_config.dtype) +def process_weights_after_loading_quant( + model: nn.Module, model_config: ModelConfig, target_device: torch.device +) -> None: + # to avoid circular dependency + from vllm.model_executor.model_loader.online_quantization import ( + maybe_save_metadata_and_attributes_for_weight_reloading, + ) + + maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. 
+ with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + +def process_weights_after_loading_mla( + model: nn.Module, model_config: ModelConfig +) -> None: + # Initialize post-load attention weights for both Attention and MLA. + # NOTE: Happens after other modules so we can easily decompress weights. + for _, module in model.named_modules(): + if isinstance(module, (Attention, MLAAttention)) and hasattr( + module, "process_weights_after_loading" + ): + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) + + @contextmanager def device_loading_context(module: torch.nn.Module, target_device: torch.device): if target_device.type == "cpu": diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 32d8da5ec1c8..c73f9927f568 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -755,6 +755,33 @@ def shutdown(self) -> None: if runner := getattr(self, "model_runner", None): runner.ensure_kv_transfer_shutdown() + def init_weights_send_group_for_remote_instance(self, config: dict) -> dict: + from vllm.model_executor.model_loader.remote_instance_loader_utils import ( + init_weights_send_group_for_remote_instance, + ) + + return init_weights_send_group_for_remote_instance( + master_address=config["master_address"], + ports=config["ports"], + group_rank=config["group_rank"], + world_size=config["world_size"], + group_name=config.get("group_name", "weight_send_group"), + backend=config.get("backend", "nccl"), + ) + + def send_weights_to_remote_instance(self, config: dict) -> dict: + from vllm.model_executor.model_loader.remote_instance_loader_utils import ( + send_weights_to_remote_instance, + ) + + return send_weights_to_remote_instance( + master_address=config["master_address"], + ports=config["ports"], + group_name=config.get("group_name", "weight_send_group"), + 
remote_tensor_nums=config.get("state_dict", -1), + model=self.model_runner.model, + ) + def init_worker_distributed_environment( vllm_config: VllmConfig,