Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions tests/model_executor/model_loader/test_remote_instance_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Tests for the remote instance model loader.

To run these tests:
1. Install test dependencies:
uv pip install -r requirements/common.txt -r requirements/dev.txt
--torch-backend=auto
uv pip install pytest pytest-asyncio

2. Run the tests:
pytest -s -v tests/model_executor/model_loader/test_remote_instance_loader.py

Note: This test skips itself at runtime (via pytest.skip) because it requires:
- Multiple GPUs (at least 8 GPUs for 2x2 TP/PP configuration for both seed
and client instances)
- Coordinated seed and client servers
- Proper setup of environment variables
- Network communication between servers
"""

from http import HTTPStatus

import pytest
import requests
from huggingface_hub import snapshot_download

from tests.utils import RemoteOpenAIServer

# Short prompts sent to both the seed and the client server; with
# temperature=0.0 (greedy decoding) both servers must produce identical text.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


@pytest.fixture(scope="module")
def llama_3p2_1b_files():
    """Download (or reuse a cached copy of) the Llama-3.2-1B-Instruct
    checkpoint and yield the local snapshot directory for the test module."""
    snapshot_dir = snapshot_download(
        "meta-llama/Llama-3.2-1B-Instruct",
        ignore_patterns=["*.bin*", "original/*"],
    )
    yield snapshot_dir


def test_remote_instance_loader_end_to_end(llama_3p2_1b_files, num_gpus_available):
    """
    End-to-end test for the remote instance loader.

    This test simulates the manual testing procedure:
    1. Start a seed server (source of weights)
    2. Start a client server (loads weights from seed server)
    3. Compare greedy completions from both servers; they must match exactly

    Note: this test skips itself at runtime (it is not statically marked)
    because it requires:
    - Multiple GPUs (at least 8 GPUs for 2x2 TP/PP configuration for both
      seed and client instances)
    - Coordinated seed and client servers
    - Proper setup of environment variables
    - Network communication between servers
    """
    # Need at least 8 GPUs (4 for seed instance + 4 for client instance)
    if num_gpus_available < 8:
        pytest.skip(
            "Not enough GPUs for 2x2 TP/PP configuration for both seed and "
            "client instances (requires 8 GPUs)"
        )

    input_dir = llama_3p2_1b_files
    # Fixed ports (auto_port=False below) so the client knows where the seed is.
    seed_port = 12346
    client_port = 12347
    gpu_memory_utilization = 0.8

    # Server arguments for both seed and client instances
    common_args = [
        "--tensor-parallel-size",
        "2",
        "--pipeline-parallel-size",
        "2",
        "--gpu-memory-utilization",
        str(gpu_memory_utilization),
        "--max-model-len",
        "1024",
        "--enforce-eager",
    ]

    # Run seed server (source of weights)
    seed_args = [
        "--host",
        "127.0.0.1",
        "--port",
        str(seed_port),
        *common_args,
    ]

    with RemoteOpenAIServer(input_dir, seed_args, auto_port=False) as seed_server:
        # Check if seed server is running
        response = requests.get(seed_server.url_for("health"))
        assert response.status_code == HTTPStatus.OK

        # Run client server (loads weights from seed server)
        # Set environment variables for remote instance loading
        # Use different GPUs for client instance to avoid conflict with seed instance
        client_env_dict = {
            "REMOTE_INSTANCE_IP": "127.0.0.1",
            "REMOTE_INSTANCE_SERVER_PORT": str(seed_port),
            "REMOTE_INSTANCE_PORTS": "[50000,50001,50002,50003]",
            "CUDA_VISIBLE_DEVICES": "4,5,6,7",  # Use different GPUs for client
        }

        client_args = [
            "--host",
            "127.0.0.1",
            "--port",
            str(client_port),
            "--load-format",
            "remote_instance",
            *common_args,
        ]

        with RemoteOpenAIServer(
            input_dir, client_args, env_dict=client_env_dict, auto_port=False
        ) as client_server:
            # Check if client server is running
            response = requests.get(client_server.url_for("health"))
            assert response.status_code == HTTPStatus.OK

            # Get OpenAI-compatible clients for both servers
            seed_client = seed_server.get_client()
            client_client = client_server.get_client()

            # Get the model name from the seed server
            seed_models = seed_client.models.list()
            seed_model_name = seed_models.data[0].id

            # Get the model name from the client server
            client_models = client_client.models.list()
            client_model_name = client_models.data[0].id

            # Generate outputs from both servers and compare. temperature=0.0
            # makes decoding greedy, so identical weights must yield
            # identical text.
            for prompt in prompts:
                # Generate from seed server
                seed_response = seed_client.completions.create(
                    model=seed_model_name,
                    prompt=prompt,
                    max_tokens=256,
                    temperature=0.0,
                )
                seed_text = seed_response.choices[0].text

                # Generate from client server
                client_response = client_client.completions.create(
                    model=client_model_name,
                    prompt=prompt,
                    max_tokens=256,
                    temperature=0.0,
                )
                client_text = client_response.choices[0].text

                # Compare outputs
                assert seed_text == client_text, (
                    f"Outputs from seed and client servers should be identical.\n"
                    f"Prompt: {prompt}\n"
                    f"Seed output: {seed_text}\n"
                    f"Client output: {client_text}"
                )
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/sgl-project/sglang/pull/8215

from datetime import timedelta
from typing import Any

import torch
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
Store,
_new_process_group_helper,
_world,
default_pg_timeout,
rendezvous,
)

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger

logger = init_logger(__name__)


class WeightTransferConnector:
    """Weight transfer connector for RemoteInstanceLoader.

    Opens a dedicated NCCL process group between this instance and a remote
    "seed" instance (identified by ``url``) so model weights can be
    transferred over the network. The KV-connector style hooks
    (``start_load_kv`` etc.) are intentional no-ops: this connector moves
    weights only, not KV cache.
    """

    def __init__(self, url: str):
        # url is "<master_address>:<master_port>" of the remote instance.
        self.url = url
        self.closed = False
        # Process group created by build_group(); destroyed in close().
        self._model_update_group = None

    def build_group(
        self,
        gpu_id: int = -1,
        client_rank: int = -1,
        client_id: str = "",
        group_rank: int = 1,
        world_size: int = 2,
    ):
        """Initialize the custom process group used for weight transfer.

        Args:
            gpu_id: Local CUDA device index to bind the group to (required).
            client_rank: Rank of this client, used in the group name (required).
            client_id: Identifier used to derive a unique group name.
            group_rank: This process's rank inside the new group.
            world_size: Total number of ranks in the new group.

        Returns:
            A ``(success, message)`` tuple; initialization failures are
            logged and reported rather than raised so callers can react.

        Raises:
            ValueError: If ``gpu_id`` or ``client_rank`` is left unspecified.
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and this is genuine input validation. (The original
        # message also named a nonexistent `tp_rank` parameter.)
        if gpu_id == -1 or client_rank == -1:
            raise ValueError(
                "gpu_id and client_rank must be specified for "
                "RemoteInstanceConnector."
            )

        # NOTE(review): hard-codes CUDA; consider current_platform so
        # non-CUDA accelerators work too -- confirm against vllm.platforms.
        self.device_id = torch.device("cuda", gpu_id)
        master_address, master_port = self.url.split(":")
        group_name = f"send_weights_{client_id}_{client_rank}"
        backend = "nccl"

        logger.info(
            "init custom process group: master_address=%s, master_port=%s, "
            "rank_offset=%s, world_size=%s, group_name=%s, backend=%s, gpu_id=%s",
            master_address,
            master_port,
            group_rank,
            world_size,
            group_name,
            backend,
            gpu_id,
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                timeout=timedelta(seconds=60),
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
                device_id=self.device_id,
            )

            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def close(self):
        """Destroy the transfer process group; safe to call more than once."""
        if self.closed:
            return
        self.closed = True
        if self._model_update_group is not None:
            torch.distributed.distributed_c10d.destroy_process_group(
                self._model_update_group
            )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        # Best-effort cleanup if the caller forgot to close().
        self.close()

    # The methods below satisfy the KV-connector interface but do nothing:
    # this connector transfers weights only.
    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        return

    def wait_for_layer_load(self, layer_name: str) -> None:
        return

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs: Any,
    ) -> None:
        return

    def wait_for_save(self):
        return


# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
backend: str | None = None,
init_method: str | None = None,
timeout: timedelta | None = None,
world_size: int = -1,
rank: int = -1,
store: Store | None = None,
group_name: str = "",
pg_options: Any | None = None,
device_id: torch.device | int | None = None,
):
assert (store is None) or (init_method is None), (
"Cannot specify both init_method and store."
)

if store is not None:
assert world_size > 0, "world_size must be positive if using store"
assert rank >= 0, "rank must be non-negative if using store"
elif init_method is None:
init_method = "env://"

backend = Backend(backend) if backend else Backend("undefined")

if timeout is None:
timeout = default_pg_timeout

# backward compatible API
if store is None:
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)

# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
store = PrefixStore(group_name, store)

# Get the rank of the current process in the default process group
my_rank = torch.distributed.get_rank()
global_ranks_in_group = [my_rank] # Must include itself at least
logger.debug("global_ranks_in_group: %s", global_ranks_in_group)

# NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
# https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
# We need to determine the appropriate parameter name based on PyTorch version
pg_options_param_name = (
"backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
)
pg, _ = _new_process_group_helper(
world_size,
rank,
global_ranks_in_group,
backend,
store,
group_name=group_name,
**{pg_options_param_name: pg_options},
timeout=timeout,
device_id=device_id,
)

_world.pg_group_ranks[pg] = {
global_rank: group_rank
for group_rank, global_rank in enumerate(global_ranks_in_group)
}
logger.debug("_world: %s", _world.pg_group_ranks[pg])
return pg
Loading