diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index ee7c6ab0a5d7..7a65814d76c2 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -231,6 +231,7 @@ steps: - tests/compile/fullgraph/test_basic_correctness.py - examples/offline_inference/rlhf.py - examples/offline_inference/rlhf_colocate.py + - examples/offline_inference/new_weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -266,10 +267,16 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests + # OLD rlhf examples - pushd ../examples/offline_inference - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd + # NEW rlhf examples + - pushd ../examples/offline_inference/new_weight_syncing + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py + - popd - label: Distributed Tests (8 GPUs) # 4min timeout_in_minutes: 10 diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index bcd9997a48f6..62a54a1f20ef 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -204,6 +204,7 @@ steps: - tests/compile/fullgraph/test_basic_correctness.py - examples/offline_inference/rlhf.py - examples/offline_inference/rlhf_colocate.py + - examples/offline_inference/new_weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -238,10 +239,16 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests + # OLD rlhf examples - pushd ../examples/offline_inference - 
VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd + # NEW rlhf examples + - pushd ../examples/offline_inference/new_weight_syncing + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py + - popd - label: Distributed Tests (8 GPUs) # 4min timeout_in_minutes: 10 @@ -1203,6 +1210,8 @@ steps: - pytest -v -s distributed/test_shm_broadcast.py - pytest -v -s distributed/test_shm_buffer.py - pytest -v -s distributed/test_shm_storage.py + - pytest -v -s distributed/test_packed_tensor.py + - pytest -v -s distributed/test_weight_transfer.py - label: 2 Node Tests (4 GPUs in total) # 16min timeout_in_minutes: 30 diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index 51e1de3f06ca..b37c7e05e311 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -63,6 +63,7 @@ steps: - tests/compile/fullgraph/test_basic_correctness.py - examples/offline_inference/rlhf.py - examples/offline_inference/rlhf_colocate.py + - examples/offline_inference/new_weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -97,9 +98,14 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests + # OLD rlhf examples - cd ../examples/offline_inference - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + # NEW rlhf examples + - cd new_weight_syncing + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - label: Distributed Tests (8 GPUs)(H100) timeout_in_minutes: 10 diff --git 
a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py new file mode 100644 index 000000000000..b3a3ca62f5a6 --- /dev/null +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Demonstrates reinforcement learning using vLLM and Ray, +with native weight syncing APIs at engine instance. + +The script separates training and inference workloads onto distinct GPUs +so that Ray can manage process placement and inter-process communication. +A Hugging Face Transformer model occupies one GPU for training, whereas a +2x tensor-parallel vLLM inference engine occupies two GPUs. + +The example performs the following steps: +* Load the training model on one gpu (scheduled via ray) +* Initialize the inference model with dummy weights across + two gpus using vLLM's tensor parallelism and Ray placement groups. +* Generate gibberish from a list of prompts using the randomly initialized + inference engine. +* Update the weights of the training model and broadcast the updated weights + to the inference engine by using a Ray collective RPC group. +* Generating from the list of prompts after weight sync should result + in sensible outputs. + +This example assumes a single-node cluster with three GPUs, but Ray +supports multi-node clusters. vLLM expects the GPUs are only used for vLLM +workloads. Residual GPU activity interferes with vLLM memory profiling and +causes unexpected behavior. 
+""" + +import os + +import ray +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from transformers import AutoModelForCausalLM + +from vllm import LLM, SamplingParams +from vllm.config import WeightTransferConfig +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLWeightTransferEngine, +) +from vllm.utils.network_utils import get_ip, get_open_port + +MODEL_NAME = "facebook/opt-125m" +# MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128" + + +class MyLLM(LLM): + """Configure the vLLM worker for Ray placement group execution.""" + + def __init__(self, *args, **kwargs): + os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1" + super().__init__(*args, **kwargs) + + +@ray.remote(num_gpus=1) +class TrainModel: + """Ray actor that wraps the training model on a dedicated GPU.""" + + def __init__(self, model_name: str): + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + ).to("cuda:0") + + self.port = get_open_port() + self.master_address = get_ip() + + def get_master_address_and_port(self): + return self.master_address, self.port + + def get_weight_metadata(self): + """Return weight names, dtypes, and shapes for weight transfer.""" + names = [] + dtype_names = [] + shapes = [] + for name, p in self.model.named_parameters(): + names.append(name) + dtype_names.append(str(p.dtype).split(".")[-1]) + shapes.append(list(p.shape)) + return names, dtype_names, shapes + + def init_weight_transfer_group(self, world_size): + """Initialize the NCCL process group for weight transfer.""" + self.model_update_group = NCCLWeightTransferEngine.trainer_init( + dict( + master_address=self.master_address, + master_port=self.port, + world_size=world_size, + ), + ) + + def broadcast_weights(self, packed: bool = True): + """Broadcast weights to the inference engine.""" + NCCLWeightTransferEngine.trainer_send_weights( + iterator=self.model.named_parameters(), + 
group=self.model_update_group, + packed=packed, + ) + + +# Initialize Ray and set the visible devices. The vLLM engine will +# be placed on GPUs 1 and 2. +ray.init() + +# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. +# Learn more about Ray placement groups: +# https://docs.ray.io/en/latest/placement-groups.html +# Launch the training model actor. Ray's resource scheduler will allocate +# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. +train_model = TrainModel.remote(MODEL_NAME) + +pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) +ray.get(pg_inference.ready()) +scheduling_inference = PlacementGroupSchedulingStrategy( + placement_group=pg_inference, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=0, +) + +# Launch the vLLM inference engine. The `enforce_eager` flag reduces +# start-up latency. +# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights) +# are now native to vLLM workers. +llm = ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=scheduling_inference, +)(MyLLM).remote( + model=MODEL_NAME, + enforce_eager=True, + tensor_parallel_size=2, + data_parallel_size=1, + distributed_executor_backend="ray", + weight_transfer_config=WeightTransferConfig(backend="nccl"), + load_format="dummy", + quantization="fp8", +) + +# Generate text from the prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +sampling_params = SamplingParams(temperature=0) + +outputs = ray.get(llm.generate.remote(prompts, sampling_params)) + +# Generate text with the initial model. The output is expected to be nonsense +# because the weights are randomly initialized. 
+print("-" * 50) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + +# Set up the communication channel between the training process and the +# inference engine. +master_address, master_port = ray.get(train_model.get_master_address_and_port.remote()) + +world_size = ray.get(llm.get_world_size.remote()) + 1 # +1 for the trainer +inference_handle = llm.init_weight_transfer_engine.remote( + dict( + init_info=dict( + master_address=master_address, + master_port=master_port, + rank_offset=1, + world_size=world_size, + ) + ) +) + +# Initialize weight transfer group on both the training actor and inference engine +train_handle = train_model.init_weight_transfer_group.remote(world_size) +ray.get([train_handle, inference_handle]) + +# Synchronize the updated weights to the inference engine using batched API. +# Collect all weight metadata from the training actor +names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote()) + +# Issue update_weights call with NCCL-specific update info +# packed=True enables efficient batched tensor broadcasting +inference_handle = llm.update_weights.remote( + dict( + update_info=dict( + names=names, + dtype_names=dtype_names, + shapes=shapes, + packed=True, + ) + ) +) + +# Broadcast all weights from trainer using the weight transfer API +train_handle = train_model.broadcast_weights.remote(packed=True) +ray.get([train_handle, inference_handle]) + +# Generate text with the updated model. The output is expected to be normal +# because the weights are updated. 
+outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) +print("-" * 50) +for output in outputs_updated: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py new file mode 100644 index 000000000000..835c16a7f55c --- /dev/null +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Demonstrates async reinforcement learning using vLLM and Ray, +with native weight syncing APIs at engine instance. + +The script separates training and inference workloads onto distinct GPUs +so that Ray can manage process placement and inter-process communication. +A Hugging Face Transformer model occupies one GPU for training, whereas a +2x tensor-parallel vLLM inference engine occupies two GPUs. + +The example performs the following steps: +* Load the training model on one gpu (scheduled via ray) +* Initialize the inference model with dummy weights across + two gpus using vLLM's tensor parallelism and Ray placement groups. +* Generate gibberish from a list of prompts using the randomly initialized + inference engine. +* Pause generation once generation completes for one sequence +* Update the weights of the training model and broadcast the updated weights + to the inference engine by using a Ray collective RPC group. +* Resume generation and print out the results + +This example assumes a single-node cluster with three GPUs, but Ray +supports multi-node clusters. vLLM expects the GPUs are only used for vLLM +workloads. Residual GPU activity interferes with vLLM memory profiling and +causes unexpected behavior. 
+""" + +import os +import uuid +from dataclasses import asdict + +import ray +import torch +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from transformers import AutoModelForCausalLM, AutoTokenizer + +import vllm +from vllm import SamplingParams +from vllm.config import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightTransferUpdateRequest, +) +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLWeightTransferEngine, + NCCLWeightTransferInitInfo, + NCCLWeightTransferUpdateInfo, +) +from vllm.utils.network_utils import get_ip, get_open_port +from vllm.v1.executor import Executor + +MODEL_NAME = "facebook/opt-125m" + + +class MyLLM(vllm.AsyncLLMEngine): + """Configure the vLLM worker for Ray placement group execution.""" + + def __init__(self, **kwargs): + os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1" + engine_args = vllm.AsyncEngineArgs(**kwargs) + vllm_config = engine_args.create_engine_config() + executor_class = Executor.get_class(vllm_config) + super().__init__( + vllm_config=vllm_config, + executor_class=executor_class, + log_requests=engine_args.enable_log_requests, + log_stats=not engine_args.disable_log_stats, + ) + + async def generate_with_retry( + self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams + ) -> vllm.RequestOutput: + finish_reason = "abort" + while finish_reason == "abort": + async for request_output in self.generate( + {"prompt_token_ids": prompt_token_ids}, + sampling_params, + request_id=str(uuid.uuid4()), + ): + output = request_output + finish_reason = output.outputs[0].finish_reason + if finish_reason == "abort": + print( + f"ABORT, prompt_token_ids: {prompt_token_ids}, " + f"generated token_ids: {list(output.outputs[0].token_ids)}" + ) + prompt_token_ids = prompt_token_ids + list(output.outputs[0].token_ids) + return output + + +@ray.remote(num_gpus=1) +class 
TrainModel: + """Ray actor that wraps the training model on a dedicated GPU.""" + + def __init__(self, model_name: str): + self.model = AutoModelForCausalLM.from_pretrained( + model_name, dtype=torch.bfloat16 + ).to("cuda:0") + self.port = get_open_port() + self.master_address = get_ip() + + def get_master_address_and_port(self): + return self.master_address, self.port + + def get_weight_metadata(self): + """Return weight names, dtypes, and shapes for weight transfer.""" + names = [] + dtype_names = [] + shapes = [] + for name, p in self.model.named_parameters(): + names.append(name) + dtype_names.append(str(p.dtype).split(".")[-1]) + shapes.append(list(p.shape)) + return names, dtype_names, shapes + + def init_weight_transfer_group(self, world_size): + """Initialize the NCCL process group for weight transfer.""" + self.model_update_group = NCCLWeightTransferEngine.trainer_init( + dict( + master_address=self.master_address, + master_port=self.port, + world_size=world_size, + ), + ) + + def broadcast_weights(self, packed: bool = True): + """Broadcast weights to the inference engine.""" + NCCLWeightTransferEngine.trainer_send_weights( + iterator=self.model.named_parameters(), + group=self.model_update_group, + packed=packed, + ) + + +# Initialize Ray and set the visible devices. The vLLM engine will +# be placed on GPUs 1 and 2. +ray.init() + +# Launch the training model actor. Ray's resource scheduler will allocate +# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. +train_model = TrainModel.remote(MODEL_NAME) + +# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. 
+# Learn more about Ray placement groups: +# https://docs.ray.io/en/latest/placement-groups.html + +pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) +ray.get(pg_inference.ready()) +scheduling_inference = PlacementGroupSchedulingStrategy( + placement_group=pg_inference, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=0, +) + +# Launch the vLLM inference engine. The `enforce_eager` flag reduces +# start-up latency. +# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights) +# are now native to vLLM workers. +llm = ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=scheduling_inference, +)(MyLLM).remote( + model=MODEL_NAME, + enforce_eager=True, + tensor_parallel_size=2, + distributed_executor_backend="ray", + load_format="dummy", + weight_transfer_config=WeightTransferConfig(backend="nccl"), +) + +# Generate text from the prompts. +prompts = [ + "My name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Tokenize prompts to token IDs +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +prompt_token_ids_list = [ + tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts +] + +sampling_params = [ + SamplingParams(temperature=0, max_tokens=2), + SamplingParams(temperature=0, max_tokens=32), + SamplingParams(temperature=0, max_tokens=32), + SamplingParams(temperature=0, max_tokens=32), +] + +# Set up the communication channel between the training process and the +# inference engine. 
+master_address, master_port = ray.get(train_model.get_master_address_and_port.remote()) + +world_size = 3 # 1 trainer + 2 inference workers (tensor_parallel_size=2) +inference_handle = llm.init_weight_transfer_engine.remote( + WeightTransferInitRequest( + init_info=asdict( + NCCLWeightTransferInitInfo( + master_address=master_address, + master_port=master_port, + rank_offset=1, + world_size=world_size, + ) + ) + ) +) + +# Initialize weight transfer group on both the training actor and inference engine +train_handle = train_model.init_weight_transfer_group.remote(world_size) +ray.get([train_handle, inference_handle]) + + +generation_futures = [ + llm.generate_with_retry.remote(prompt_token_ids, params) + for prompt_token_ids, params in zip(prompt_token_ids_list, sampling_params) +] + +finished, pending = ray.wait(generation_futures, num_returns=1) + +# Pause generation in preparation for weight sync +ray.get(llm.pause_generation.remote(wait_for_inflight_requests=False)) + +# Synchronize the updated weights to the inference engine using batched API. 
+# Collect all weight metadata from the training actor +names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote()) + +# Issue update_weights call with NCCL-specific update info +# packed=True enables efficient batched tensor broadcasting +inference_handle = llm.update_weights.remote( + WeightTransferUpdateRequest( + update_info=asdict( + NCCLWeightTransferUpdateInfo( + names=names, + dtype_names=dtype_names, + shapes=shapes, + packed=True, + ) + ) + ) +) + +# Broadcast all weights from trainer using the weight transfer API +train_handle = train_model.broadcast_weights.remote(packed=True) +ray.get([train_handle, inference_handle]) + +# Resume generation since weight sync is complete +ray.get(llm.resume_generation.remote()) + +# Get outputs separately - finished completed before pause, pending were paused/resumed +finished_outputs = ray.get(finished) +pending_outputs = ray.get(pending) + +# Requests that finished before the pause: all generation used original weights +print("-" * 50) +print("Requests that completed BEFORE weight change:") +print("-" * 50) +for output in finished_outputs: + prompt_text = tokenizer.decode(output.prompt_token_ids) + print(f"Prompt: {prompt_text!r}") + print(f"Generated (with original weights): {output.outputs[0].text!r}") + print("-" * 50) + +# Requests that were paused mid-generation: some text before, some after weight change +print("Requests that were PAUSED and RESUMED after weight change:") +print("-" * 50) +for output in pending_outputs: + # Decode the full prompt token IDs (original + generated before pause) + full_prompt_text = tokenizer.decode(output.prompt_token_ids) + # Find the original prompt by checking which one this output started with + original_prompt = next(p for p in prompts if full_prompt_text.startswith(p)) + # output.prompt_token_ids contains original prompt + tokens generated before pause + # output.outputs[0].text is what was generated after resuming with new weights + text_before_pause = 
full_prompt_text[len(original_prompt) :] + text_after_pause = output.outputs[0].text + print(f"Original prompt: {original_prompt!r}") + print(f"Generated before weight change: {text_before_pause!r}") + print(f"Generated after weight change: {text_after_pause!r}") + print("-" * 50) diff --git a/examples/online_serving/rlhf_http.py b/examples/online_serving/rlhf_http.py new file mode 100644 index 000000000000..721a038a6600 --- /dev/null +++ b/examples/online_serving/rlhf_http.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Demonstrates reinforcement learning from human feedback (RLHF) using vLLM +via HTTP API, with native weight syncing APIs. + +Unlike rlhf.py which creates a vLLM instance programmatically, this script +assumes you have already started a vLLM server using `vllm serve`. It uses: +- OpenAI-compatible API for inference requests +- HTTP endpoints for weight transfer control plane +- NCCL for actual weight data transfer + +Prerequisites: + Start a vLLM server with weight transfer enabled: + + $ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \ + --enforce-eager \ + --weight-transfer-config '{"backend": "nccl"}' \ + --load-format dummy + + Then run this script: + + $ python rlhf_http.py + +The example performs the following steps: + +* Load the training model on GPU 0. +* Generate text using the vLLM server via OpenAI-compatible API. The output + is expected to be nonsense because the server is initialized with dummy weights. +* Initialize weight transfer via HTTP endpoint. +* Broadcast the real weights from the training model to the vLLM server + using NCCL. +* Generate text again to show normal output after the weight update. 
+""" + +import requests +import torch +from openai import OpenAI +from transformers import AutoModelForCausalLM + +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLWeightTransferEngine, +) +from vllm.utils.network_utils import get_ip, get_open_port + +BASE_URL = "http://localhost:8000" +MODEL_NAME = "facebook/opt-125m" + + +def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]: + """Generate completions using the OpenAI-compatible API.""" + results = [] + for prompt in prompts: + response = client.completions.create( + model=model, + prompt=prompt, + max_tokens=32, + temperature=0, + ) + results.append(response.choices[0].text) + return results + + +def init_weight_transfer_engine( + base_url: str, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, +) -> None: + """Initialize weight transfer via HTTP endpoint.""" + url = f"{base_url}/init_weight_transfer_engine" + payload = { + "init_info": dict( + master_address=master_address, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, + ) + } + response = requests.post(url, json=payload, timeout=60) + response.raise_for_status() + + +def update_weights( + base_url: str, + names: list[str], + dtype_names: list[str], + shapes: list[list[int]], + packed: bool = False, +) -> None: + """Update weights via HTTP endpoint.""" + url = f"{base_url}/update_weights" + payload = { + "update_info": dict( + names=names, + dtype_names=dtype_names, + shapes=shapes, + packed=packed, + ) + } + response = requests.post(url, json=payload, timeout=300) + response.raise_for_status() + + +def pause_generation(base_url: str) -> None: + """Pause generation via HTTP endpoint.""" + url = f"{base_url}/pause" + response = requests.post(url, timeout=60) + response.raise_for_status() + + +def resume_generation(base_url: str) -> None: + """Resume generation via HTTP endpoint.""" + url = f"{base_url}/resume" + response = requests.post(url, 
timeout=60) + response.raise_for_status() + + +def get_world_size(base_url: str) -> int: + """Get world size from the vLLM server.""" + url = f"{base_url}/get_world_size" + response = requests.get(url, timeout=10) + response.raise_for_status() + return response.json()["world_size"] + + +def main(): + # Get the inference world size from the vLLM server + inference_world_size = get_world_size(BASE_URL) + world_size = inference_world_size + 1 # +1 for the trainer + device = f"cuda:{inference_world_size}" + torch.cuda.set_device(device) + + # Load the training model + print(f"Loading training model: {MODEL_NAME}") + train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16) + train_model.to(device) + + # Create OpenAI client pointing to the vLLM server + client = OpenAI( + base_url=f"{BASE_URL}/v1", + api_key="EMPTY", # vLLM doesn't require an API key by default + ) + + # Test prompts + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Generate text before weight update. The output is expected to be nonsense + # because the server is initialized with dummy weights. + print("-" * 50) + print("Generating text BEFORE weight update (expect nonsense):") + print("-" * 50) + outputs = generate_completions(client, MODEL_NAME, prompts) + for prompt, generated_text in zip(prompts, outputs): + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + # Set up the communication channel between the training process and the + # vLLM server. The trainer is rank 0, vLLM worker(s) start at rank_offset. 
+ master_address = get_ip() + master_port = get_open_port() + rank_offset = 1 + + print(f"Initializing weight transfer: master={master_address}:{master_port}") + + # Initialize weight transfer on vLLM server (this is async, server will + # wait for NCCL connection) + import threading + + init_thread = threading.Thread( + target=init_weight_transfer_engine, + args=(BASE_URL, master_address, master_port, rank_offset, world_size), + ) + init_thread.start() + + # Initialize NCCL process group on trainer side + model_update_group = NCCLWeightTransferEngine.trainer_init( + dict( + master_address=master_address, + master_port=master_port, + world_size=world_size, + ), + ) + + # Wait for init_weight_transfer_engine to complete + init_thread.join() + + # Pause generation before weight sync + pause_generation(BASE_URL) + + # Collect weight metadata for the update request + names = [] + dtype_names = [] + shapes = [] + for name, p in train_model.named_parameters(): + names.append(name) + dtype_names.append(str(p.dtype).split(".")[-1]) + shapes.append(list(p.shape)) + + # Start the update_weights call in a separate thread since it will block + # waiting for NCCL broadcasts + # packed=True enables efficient batched tensor broadcasting + update_thread = threading.Thread( + target=update_weights, + args=(BASE_URL, names, dtype_names, shapes, True), # packed=True + ) + update_thread.start() + + # Broadcast all weights from trainer to vLLM workers + print("Broadcasting weights via NCCL...") + NCCLWeightTransferEngine.trainer_send_weights( + iterator=train_model.named_parameters(), + group=model_update_group, + packed=True, + ) + + # Wait for update_weights to complete + update_thread.join() + + # Resume generation after weight sync + resume_generation(BASE_URL) + + # Generate text after weight update. The output is expected to be normal + # because the real weights are now loaded. 
+ print("-" * 50) + print("Generating text AFTER weight update:") + print("-" * 50) + outputs_updated = generate_completions(client, MODEL_NAME, prompts) + for prompt, generated_text in zip(prompts, outputs_updated): + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/tests/distributed/test_packed_tensor.py b/tests/distributed/test_packed_tensor.py new file mode 100644 index 000000000000..134629e2b790 --- /dev/null +++ b/tests/distributed/test_packed_tensor.py @@ -0,0 +1,443 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for packed tensor broadcasting functionality. + +Unit tests for packed_broadcast_producer and packed_broadcast_consumer. +These utilities enable efficient batched tensor transfer over NCCL. +""" + +import pytest +import torch + +from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferUpdateInfo +from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_consumer, + packed_broadcast_producer, +) + + +class MockCommunicationGroup: + """Mock communication group for testing producer broadcast operations.""" + + def __init__(self): + self.broadcasted_tensors: list[torch.Tensor] = [] + self.broadcast_count = 0 + self.device = torch.device("cuda:0") + + def broadcast(self, tensor, src): + """Mock broadcast that stores the tensor for later verification.""" + self.broadcasted_tensors.append(tensor.clone()) + self.broadcast_count += 1 + + +class MockConsumerCommunicationGroup: + """Mock communication group for consumer that returns pre-stored tensors.""" + + def __init__(self, tensors_to_return: list[torch.Tensor]): + self.tensors_to_return = tensors_to_return + self.current_index = 0 + self.device = torch.device("cuda:0") + + def broadcast(self, tensor, src): + """Mock broadcast that fills the tensor with pre-stored data.""" + if self.current_index < 
len(self.tensors_to_return): + tensor.copy_(self.tensors_to_return[self.current_index]) + self.current_index += 1 + + +def create_mock_model_params( + num_layers: int = 3, + dtype: torch.dtype = torch.float32, +) -> list[tuple[str, torch.Tensor]]: + """Create mock model parameters for testing.""" + params = [] + for i in range(num_layers): + params.append((f"layer{i}.weight", torch.randn(10, 20, dtype=dtype))) + params.append((f"layer{i}.bias", torch.randn(10, dtype=dtype))) + return params + + +def create_state_dict_info( + params: list[tuple[str, torch.Tensor]], +) -> dict[str, tuple[tuple[int, ...], torch.dtype]]: + """Create state dict info (name -> (shape, dtype)) from params.""" + return {name: (tuple(tensor.shape), tensor.dtype) for name, tensor in params} + + +# --- Unit Tests: NCCLWeightTransferUpdateInfo packed field --- + + +class TestNCCLWeightTransferUpdateInfoPacked: + """Test NCCLWeightTransferUpdateInfo dataclass packed field.""" + + def test_packed_default_false(self): + """Test that packed defaults to False.""" + info = NCCLWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + ) + assert info.packed is False + + def test_packed_can_be_set_true(self): + """Test that packed can be set to True.""" + info = NCCLWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + packed=True, + ) + assert info.packed is True + + +# --- Unit Tests: packed_broadcast_producer --- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestPackedBroadcastProducer: + """Test packed_broadcast_producer function.""" + + def test_producer_broadcasts_tensors(self): + """Test that producer broadcasts all tensors.""" + params = create_mock_model_params() + params_cuda = [(name, tensor.cuda()) for name, tensor in params] + + mock_group = MockCommunicationGroup() + + # Use a small target size to force multiple batches + packed_broadcast_producer( + 
iterator=iter(params_cuda), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + buffer_size_bytes=500, + ) + + # Should have broadcasted some tensors + assert mock_group.broadcast_count > 0 + assert len(mock_group.broadcasted_tensors) > 0 + + def test_producer_single_large_tensor(self): + """Test with a single tensor larger than target size.""" + # Create a large tensor + large_tensor = torch.randn(1000, 1000, dtype=torch.float32).cuda() + params = [("large_weight", large_tensor)] + + mock_group = MockCommunicationGroup() + + # Small target size to force the tensor to exceed it + packed_broadcast_producer( + iterator=iter(params), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + buffer_size_bytes=100, + ) + + # Should still broadcast the tensor (at least 1 broadcast) + assert mock_group.broadcast_count >= 1 + assert len(mock_group.broadcasted_tensors) >= 1 + + # Verify the total broadcasted size matches the tensor + expected_size = large_tensor.numel() * large_tensor.element_size() + actual_size = sum(t.numel() for t in mock_group.broadcasted_tensors) + assert actual_size == expected_size + + def test_producer_multiple_batches(self): + """Test that tensors are properly batched when exceeding target size.""" + # Create many small tensors + params = [ + (f"weight_{i}", torch.randn(10, 10, dtype=torch.float32).cuda()) + for i in range(20) + ] + + mock_group = MockCommunicationGroup() + + # Small target size to force multiple batches + packed_broadcast_producer( + iterator=iter(params), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + buffer_size_bytes=2000, + ) + + # Should have multiple broadcasts + assert mock_group.broadcast_count > 1 + + # Total size should match sum of all tensors + expected_total = sum(t.numel() * t.element_size() for _, t in params) + actual_total = sum(t.numel() for t in mock_group.broadcasted_tensors) + assert actual_total == expected_total + + def test_producer_empty_iterator(self): + """Test 
producer handles empty iterator gracefully.""" + mock_group = MockCommunicationGroup() + + packed_broadcast_producer( + iterator=iter([]), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + buffer_size_bytes=1000, + ) + + # No broadcasts for empty iterator + assert mock_group.broadcast_count == 0 + + +# --- Unit Tests: packed_broadcast_consumer --- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestPackedBroadcastConsumer: + """Test packed_broadcast_consumer function.""" + + def test_consumer_receives_tensors(self): + """Test that consumer receives and unpacks tensors.""" + params = create_mock_model_params() + params_cuda = [(name, tensor.cuda()) for name, tensor in params] + + buffer_size = 2000 + + # First, run producer to get the broadcasted tensors + producer_group = MockCommunicationGroup() + + packed_broadcast_producer( + iterator=iter(params_cuda), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + buffer_size_bytes=buffer_size, + ) + + # Now run consumer with the broadcasted tensors + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + state_dict_info = create_state_dict_info(params_cuda) + + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() + + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + buffer_size_bytes=buffer_size, + ) + + # Verify all parameters were unpacked + assert len(unpacked_tensors) == len(params) + + # Verify each tensor matches the original + for name, original_tensor in params_cuda: + assert name in unpacked_tensors + unpacked = unpacked_tensors[name] + assert unpacked.shape == original_tensor.shape + assert unpacked.dtype == original_tensor.dtype + assert torch.allclose(unpacked, original_tensor, rtol=1e-5, atol=1e-7) + + +# --- Integration 
Tests: Producer-Consumer Roundtrip --- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestPackedBroadcastRoundtrip: + """Test producer-consumer roundtrip behavior.""" + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_roundtrip_different_dtypes(self, dtype): + """Test roundtrip with different data types.""" + params = create_mock_model_params(num_layers=2, dtype=dtype) + params_cuda = [(name, tensor.cuda()) for name, tensor in params] + + buffer_size = 1000 + producer_group = MockCommunicationGroup() + + packed_broadcast_producer( + iterator=iter(params_cuda), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + buffer_size_bytes=buffer_size, + ) + + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + state_dict_info = create_state_dict_info(params_cuda) + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() + + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + buffer_size_bytes=buffer_size, + ) + + # Verify roundtrip preserves data + for name, original_tensor in params_cuda: + assert name in unpacked_tensors + unpacked = unpacked_tensors[name] + assert unpacked.dtype == dtype + assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6) + + def test_roundtrip_mixed_dtypes(self): + """Test roundtrip with mixed data types.""" + # Create params with mixed dtypes + params = [ + ("layer1.weight", torch.randn(10, 20, dtype=torch.float32).cuda()), + ("layer1.bias", torch.randn(10, dtype=torch.float16).cuda()), + ("layer2.weight", torch.randn(20, 30, dtype=torch.bfloat16).cuda()), + ] + + buffer_size = 500 + producer_group = MockCommunicationGroup() + + packed_broadcast_producer( + iterator=iter(params), + group=producer_group, + src=0, + 
post_iter_func=lambda x: x[1], + buffer_size_bytes=buffer_size, + ) + + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + state_dict_info = create_state_dict_info(params) + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() + + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + buffer_size_bytes=buffer_size, + ) + + # Verify all params roundtrip correctly with correct dtypes + for name, original_tensor in params: + assert name in unpacked_tensors + unpacked = unpacked_tensors[name] + assert unpacked.shape == original_tensor.shape + assert unpacked.dtype == original_tensor.dtype + assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6) + + @pytest.mark.parametrize("target_size", [100, 1000, 10000, 100000]) + def test_roundtrip_different_batch_sizes(self, target_size): + """Test roundtrip with different target batch sizes.""" + params = create_mock_model_params(num_layers=5) + params_cuda = [(name, tensor.cuda()) for name, tensor in params] + + producer_group = MockCommunicationGroup() + + packed_broadcast_producer( + iterator=iter(params_cuda), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + buffer_size_bytes=target_size, + ) + + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + state_dict_info = create_state_dict_info(params_cuda) + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() + + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + buffer_size_bytes=target_size, + ) + + # Verify all params roundtrip correctly + assert len(unpacked_tensors) == len(params) + for name, original_tensor in 
params_cuda: + assert name in unpacked_tensors + assert torch.allclose( + unpacked_tensors[name], original_tensor, rtol=1e-5, atol=1e-7 + ) + + def test_roundtrip_non_contiguous_tensors(self): + """Test roundtrip with non-contiguous tensors from the trainer.""" + # Create non-contiguous tensors (simulating trainer outputs) + # Transposed tensors are non-contiguous + weight1 = torch.randn(20, 10, dtype=torch.float32).cuda().T + # Sliced tensors with step are non-contiguous + weight2 = torch.randn(40, 30, dtype=torch.float16).cuda()[::2, ::2] + # Permuted tensors are non-contiguous + weight3 = torch.randn(5, 10, 15, dtype=torch.bfloat16).cuda().permute(2, 0, 1) + + params = [ + ("layer1.weight", weight1), + ("layer2.weight", weight2), + ("layer3.weight", weight3), + ] + + # Verify tensors are indeed non-contiguous + for name, tensor in params: + assert not tensor.is_contiguous(), f"{name} should be non-contiguous" + + buffer_size = 500 + producer_group = MockCommunicationGroup() + + packed_broadcast_producer( + iterator=iter(params), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + buffer_size_bytes=buffer_size, + ) + + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + state_dict_info = create_state_dict_info(params) + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() + + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + buffer_size_bytes=buffer_size, + ) + + # Verify all non-contiguous params roundtrip correctly + for name, original_tensor in params: + assert name in unpacked_tensors + unpacked = unpacked_tensors[name] + assert unpacked.shape == original_tensor.shape + assert unpacked.dtype == original_tensor.dtype + assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6) diff --git 
a/tests/distributed/test_weight_transfer.py b/tests/distributed/test_weight_transfer.py new file mode 100644 index 000000000000..4c348dd799b5 --- /dev/null +++ b/tests/distributed/test_weight_transfer.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for weight transfer engine backends. + +Unit tests for engine classes (parsing, validation, registry). +Integration test for NCCL weight transfer between processes using Ray. +""" + +from unittest.mock import MagicMock + +import pytest +import ray +import torch + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer import WeightTransferEngineFactory +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLWeightTransferEngine, + NCCLWeightTransferInitInfo, + NCCLWeightTransferUpdateInfo, +) +from vllm.utils.network_utils import get_open_port + + +def create_mock_parallel_config( + rank: int = 0, + world_size: int = 1, + dp_rank: int = 0, +) -> ParallelConfig: + """Create a mock ParallelConfig for testing.""" + config = MagicMock(spec=ParallelConfig) + config.rank = rank + config.world_size = world_size + config.data_parallel_rank = dp_rank + return config + + +# --- Unit Tests: NCCLWeightTransferUpdateInfo Validation --- + + +class TestNCCLWeightTransferUpdateInfoValidation: + """Test NCCLWeightTransferUpdateInfo dataclass validation.""" + + def test_valid_update_info(self): + """Test creating valid NCCLWeightTransferUpdateInfo.""" + info = NCCLWeightTransferUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32", "float32"], + shapes=[[10, 10], [10]], + ) + assert info.names == ["layer.weight", "layer.bias"] + assert info.dtype_names == ["float32", "float32"] + assert info.shapes == [[10, 10], [10]] + + def test_mismatched_dtype_names_raises(self): + """Test that mismatched dtype_names length raises 
ValueError.""" + with pytest.raises(ValueError, match="dtype_names"): + NCCLWeightTransferUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32"], # Only one dtype + shapes=[[10, 10], [10]], + ) + + def test_mismatched_shapes_raises(self): + """Test that mismatched shapes length raises ValueError.""" + with pytest.raises(ValueError, match="shapes"): + NCCLWeightTransferUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32", "float32"], + shapes=[[10, 10]], # Only one shape + ) + + def test_empty_lists_valid(self): + """Test that empty lists are valid.""" + info = NCCLWeightTransferUpdateInfo( + names=[], + dtype_names=[], + shapes=[], + ) + assert len(info.names) == 0 + + +# --- Unit Tests: Engine Parsing --- + + +class TestNCCLEngineParsing: + """Test NCCLWeightTransferEngine parsing methods.""" + + def test_parse_init_info_valid(self): + """Test parsing valid init info dict.""" + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = NCCLWeightTransferEngine(config, parallel_config) + + init_info = engine.parse_init_info( + { + "master_address": "127.0.0.1", + "master_port": 12345, + "rank_offset": 1, + "world_size": 3, + } + ) + + assert isinstance(init_info, NCCLWeightTransferInitInfo) + assert init_info.master_address == "127.0.0.1" + assert init_info.master_port == 12345 + assert init_info.rank_offset == 1 + assert init_info.world_size == 3 + + def test_parse_init_info_missing_field_raises(self): + """Test parsing init info with missing required field.""" + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = NCCLWeightTransferEngine(config, parallel_config) + + with pytest.raises(ValueError, match="Invalid init_info"): + engine.parse_init_info( + { + "master_address": "127.0.0.1", + # Missing master_port, rank_offset, world_size + } + ) + + def test_parse_update_info_valid(self): + """Test parsing valid update 
info dict.""" + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = NCCLWeightTransferEngine(config, parallel_config) + + update_info = engine.parse_update_info( + { + "names": ["w1", "w2"], + "dtype_names": ["float32", "bfloat16"], + "shapes": [[100, 100], [50]], + } + ) + + assert isinstance(update_info, NCCLWeightTransferUpdateInfo) + assert update_info.names == ["w1", "w2"] + assert update_info.dtype_names == ["float32", "bfloat16"] + assert update_info.shapes == [[100, 100], [50]] + + +# --- Unit Tests: Engine Registry --- + + +class TestEngineRegistry: + """Test weight transfer engine registry.""" + + def test_create_engine_nccl(self): + """Test factory creates NCCL engine.""" + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = WeightTransferEngineFactory.create_engine(config, parallel_config) + assert isinstance(engine, NCCLWeightTransferEngine) + + def test_create_engine_invalid_backend(self): + """Test factory raises for invalid backend.""" + config = WeightTransferConfig(backend="invalid") + parallel_config = create_mock_parallel_config() + with pytest.raises(ValueError, match="Invalid weight transfer backend"): + WeightTransferEngineFactory.create_engine(config, parallel_config) + + def test_register_duplicate_raises(self): + """Test registering duplicate engine name raises.""" + with pytest.raises(ValueError, match="already registered"): + WeightTransferEngineFactory.register_engine( + "nccl", NCCLWeightTransferEngine + ) + + +# --- Test receive_weights without init raises --- + + +def test_nccl_receive_weights_without_init_raises(): + """Test that receive_weights raises if init_transfer_engine wasn't called.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = NCCLWeightTransferEngine(config, 
parallel_config) + + update_info = NCCLWeightTransferUpdateInfo( + names=["w"], + dtype_names=["float32"], + shapes=[[10]], + ) + + with pytest.raises(RuntimeError, match="not initialized"): + engine.receive_weights(update_info, lambda x: None) + + +# --- Integration Test: NCCL Weight Transfer Between Ray Tasks --- + + +@ray.remote(num_gpus=1) +def trainer_broadcast_tensor( + master_address: str, + master_port: int, + world_size: int, + tensor_shape: list[int], + tensor_dtype: str, +) -> bool: + """Trainer task that broadcasts a tensor via NCCL.""" + import torch + + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + # Create process group as rank 0 (trainer) + pg = StatelessProcessGroup.create( + host=master_address, + port=master_port, + rank=0, + world_size=world_size, + ) + # Ray sets CUDA_VISIBLE_DEVICES, so device 0 is the assigned GPU + comm = PyNcclCommunicator(pg, device=0) + + # Create and broadcast the tensor + dtype = getattr(torch, tensor_dtype) + tensor_to_send = torch.ones(tensor_shape, dtype=dtype, device="cuda:0") + comm.broadcast(tensor_to_send, src=0, stream=torch.cuda.current_stream()) + torch.cuda.synchronize() + + return True + + +@ray.remote(num_gpus=1) +def inference_receive_tensor( + master_address: str, + master_port: int, + world_size: int, + tensor_shape: list[int], + tensor_dtype: str, +) -> dict: + """Inference task that receives tensor via NCCLWeightTransferEngine.""" + from unittest.mock import MagicMock + + import torch + + from vllm.config.parallel import ParallelConfig + from vllm.config.weight_transfer import WeightTransferConfig + from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLWeightTransferEngine, + NCCLWeightTransferInitInfo, + NCCLWeightTransferUpdateInfo, + ) + + # Create engine with mock parallel config + config = WeightTransferConfig(backend="nccl") + parallel_config = MagicMock(spec=ParallelConfig) + 
parallel_config.rank = 0 + parallel_config.world_size = 1 + parallel_config.data_parallel_rank = 0 + + engine = NCCLWeightTransferEngine(config, parallel_config) + + # Initialize the engine (joins as rank 1) + init_info = NCCLWeightTransferInitInfo( + master_address=master_address, + master_port=master_port, + rank_offset=1, # Trainer is rank 0, we become rank 1 + world_size=world_size, + ) + engine.init_transfer_engine(init_info) + + # Receive weights with a no-op load_weights that captures the tensor + received_tensors = [] + + def noop_load_weights(weights: list[tuple[str, torch.Tensor]]): + for name, tensor in weights: + # Clone tensor to keep it after engine cleans up + received_tensors.append((name, tensor.clone())) + + update_info = NCCLWeightTransferUpdateInfo( + names=["test.weight"], + dtype_names=[tensor_dtype], + shapes=[tensor_shape], + ) + engine.receive_weights(update_info, noop_load_weights) + torch.cuda.synchronize() + + # Verify we received the tensor + success = False + received_shape = None + received_sum = None + + if len(received_tensors) == 1: + name, tensor = received_tensors[0] + received_shape = list(tensor.shape) + received_sum = tensor.sum().item() + # Check shape matches and values are all 1s (trainer sends ones) + if received_shape == tensor_shape: + expected_sum = 1.0 * torch.tensor(tensor_shape).prod().item() + if abs(received_sum - expected_sum) < 0.01: + success = True + + engine.shutdown() + + return { + "success": success, + "received_shape": received_shape, + "received_sum": received_sum, + } + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run NCCL weight transfer test.", +) +def test_nccl_weight_transfer_between_processes(): + """Test NCCL weight transfer from trainer to inference process using Ray. + + This test verifies that the NCCLWeightTransferEngine can receive + tensors broadcast by a trainer process via NCCL. 
+ """ + ray.init(ignore_reinit_error=True) + + master_address = "127.0.0.1" + master_port = get_open_port() + world_size = 2 # 1 trainer + 1 inference worker + + # Tensor to transfer: 100x100 ones + tensor_shape = [100, 100] + tensor_dtype = "float32" + + # Start both tasks concurrently - Ray assigns GPUs automatically + inference_future = inference_receive_tensor.remote( + master_address, master_port, world_size, tensor_shape, tensor_dtype + ) + trainer_future = trainer_broadcast_tensor.remote( + master_address, master_port, world_size, tensor_shape, tensor_dtype + ) + + # Wait for both to complete + trainer_result, result = ray.get([trainer_future, inference_future]) + + assert trainer_result, "Trainer should complete successfully" + assert result["success"], ( + f"Weight transfer failed. " + f"Received shape: {result['received_shape']}, " + f"Received sum: {result['received_sum']}" + ) diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 50d24a400549..1baab9934fdd 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -139,6 +139,14 @@ def test_openapi_stateless(case: schemathesis.Case): # Skip responses API as it is meant to be stateful. return + # Skip weight transfer endpoints as they require special setup + # (weight_transfer_config) and are meant to be stateful. 
+ if case.operation.path in ( + "/init_weight_transfer_engine", + "/update_weights", + ): + return + timeout = { # requires a longer timeout ("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS, diff --git a/tests/entrypoints/weight_transfer/__init__.py b/tests/entrypoints/weight_transfer/__init__.py new file mode 100644 index 000000000000..6655f8913623 --- /dev/null +++ b/tests/entrypoints/weight_transfer/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + diff --git a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py new file mode 100644 index 000000000000..9f2309c765b5 --- /dev/null +++ b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py @@ -0,0 +1,300 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for weight transfer APIs via LLM class. + +These tests use a mock weight transfer engine to verify that the API +calls the correct methods with the right arguments, without requiring +actual NCCL communication. 
+""" + +import os +from collections.abc import Callable +from dataclasses import dataclass +from unittest.mock import patch + +import pytest +import torch + +from vllm import LLM +from vllm.config import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferInitRequest, + WeightTransferUpdateInfo, + WeightTransferUpdateRequest, +) + +from ...utils import create_new_process_for_each_test + +# Use a tiny model for fast testing +MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" + + +# --- Mock Weight Transfer Engine --- + + +@dataclass +class MockInitInfo(WeightTransferInitInfo): + """Mock initialization info.""" + + test_param: str = "test" + + +@dataclass +class MockUpdateInfo(WeightTransferUpdateInfo): + """Mock update info.""" + + names: list[str] | None = None + dtype_names: list[str] | None = None + shapes: list[list[int]] | None = None + + +class MockWeightTransferEngine(WeightTransferEngine[MockInitInfo, MockUpdateInfo]): + """Mock weight transfer engine that tracks method calls.""" + + init_info_cls = MockInitInfo + update_info_cls = MockUpdateInfo + + # Class-level tracking for verification across processes + init_transfer_engine_called: bool = False + receive_weights_called: bool = False + shutdown_called: bool = False + last_init_info: MockInitInfo | None = None + last_update_info: MockUpdateInfo | None = None + + def __init__(self, config, parallel_config): + super().__init__(config, parallel_config) + # Reset tracking on init + MockWeightTransferEngine.init_transfer_engine_called = False + MockWeightTransferEngine.receive_weights_called = False + MockWeightTransferEngine.shutdown_called = False + MockWeightTransferEngine.last_init_info = None + MockWeightTransferEngine.last_update_info = None + + def init_transfer_engine(self, init_info: MockInitInfo) -> None: + MockWeightTransferEngine.init_transfer_engine_called = True + MockWeightTransferEngine.last_init_info = 
init_info + + def receive_weights( + self, + update_info: MockUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + MockWeightTransferEngine.receive_weights_called = True + MockWeightTransferEngine.last_update_info = update_info + # Simulate loading weights by calling load_weights with empty list + # (In real implementation, this would receive and load actual weights) + load_weights([]) + + def shutdown(self) -> None: + MockWeightTransferEngine.shutdown_called = True + + +def mock_create_engine(config, parallel_config): + """Mock factory function that returns our mock engine.""" + return MockWeightTransferEngine(config, parallel_config) + + +# --- Tests --- + + +@create_new_process_for_each_test() +def test_get_world_size_tp1(): + """Test world_size is correctly configured for TP=1.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + world_size = llm.llm_engine.vllm_config.parallel_config.world_size + assert world_size == 1 + + +@create_new_process_for_each_test() +def test_init_weight_transfer_engine_calls_engine(): + """Test that init_weight_transfer_engine calls the engine's + init_transfer_engine method.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + # Enable insecure serialization to allow pickling functions for collective_rpc + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + with patch( + "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine", + mock_create_engine, + ): + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + # Verify engine was created + def check_engine_exists(self): + return self.weight_transfer_engine 
is not None + + results = llm.collective_rpc(check_engine_exists) + assert all(results), "Weight transfer engine should be initialized" + + # Call init_weight_transfer_engine + llm.init_weight_transfer_engine( + WeightTransferInitRequest(init_info={"test_param": "hello"}) + ) + + # Verify init_transfer_engine was called on the engine + def check_init_called(self): + engine = self.weight_transfer_engine + return ( + engine.init_transfer_engine_called, + engine.last_init_info.test_param if engine.last_init_info else None, + ) + + results = llm.collective_rpc(check_init_called) + for called, param in results: + assert called, "init_transfer_engine should have been called" + assert param == "hello", f"Expected 'hello', got {param}" + + +@create_new_process_for_each_test() +def test_update_weights_calls_engine(): + """Test that update_weights calls the engine's receive_weights method.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + # Enable insecure serialization to allow pickling functions for collective_rpc + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + with patch( + "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine", + mock_create_engine, + ): + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + # First init the weight transfer + llm.init_weight_transfer_engine( + WeightTransferInitRequest(init_info={"test_param": "init"}) + ) + + # Call update_weights + test_names = ["layer.weight", "layer.bias"] + test_dtypes = ["float32", "float32"] + test_shapes = [[10, 10], [10]] + + llm.update_weights( + WeightTransferUpdateRequest( + update_info={ + "names": test_names, + "dtype_names": test_dtypes, + "shapes": test_shapes, + } + ) + ) + + # Verify receive_weights was called with correct info + def check_update_called(self): + engine = self.weight_transfer_engine + if not 
engine.receive_weights_called: + return False, None, None, None + info = engine.last_update_info + return (True, info.names, info.dtype_names, info.shapes) + + results = llm.collective_rpc(check_update_called) + for called, names, dtypes, shapes in results: + assert called, "receive_weights should have been called" + assert names == test_names + assert dtypes == test_dtypes + assert shapes == test_shapes + + +@create_new_process_for_each_test() +def test_full_weight_transfer_flow(): + """Test the complete weight transfer flow: init -> update.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + # Enable insecure serialization to allow pickling functions for collective_rpc + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + with patch( + "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine", + mock_create_engine, + ): + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + # Step 1: Initialize + llm.init_weight_transfer_engine( + WeightTransferInitRequest(init_info={"test_param": "flow_test"}) + ) + + # Step 2: Update weights + llm.update_weights( + WeightTransferUpdateRequest( + update_info={ + "names": ["test.weight"], + "dtype_names": ["bfloat16"], + "shapes": [[100, 100]], + } + ) + ) + + # Verify the full flow completed + def check_flow(self): + engine = self.weight_transfer_engine + return { + "init_called": engine.init_transfer_engine_called, + "update_called": engine.receive_weights_called, + "init_param": ( + engine.last_init_info.test_param if engine.last_init_info else None + ), + "update_names": ( + engine.last_update_info.names if engine.last_update_info else None + ), + } + + results = llm.collective_rpc(check_flow) + for result in results: + assert result["init_called"], "init_transfer_engine should be called" + assert result["update_called"], "receive_weights should be 
called" + assert result["init_param"] == "flow_test" + assert result["update_names"] == ["test.weight"] + + +@create_new_process_for_each_test() +def test_weight_transfer_config_backend(): + """Test that WeightTransferConfig backend is properly configured.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + # Test with nccl backend + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + config = llm.llm_engine.vllm_config.weight_transfer_config + assert config.backend == "nccl" diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index b2044c6e1d03..6014f642c57d 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -47,6 +47,7 @@ get_layers_from_vllm_config, set_current_vllm_config, ) +from vllm.config.weight_transfer import WeightTransferConfig # __all__ should only contain classes and functions. # Types and globals should be imported from their respective modules. @@ -111,4 +112,5 @@ "get_current_vllm_config_or_none", "set_current_vllm_config", "get_layers_from_vllm_config", + "WeightTransferConfig", ] diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 846ed50e0bdd..dbd84785e36d 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -42,6 +42,7 @@ from .speculative import EagleModelTypes, SpeculativeConfig from .structured_outputs import StructuredOutputsConfig from .utils import SupportsHash, config, replace +from .weight_transfer import WeightTransferConfig if TYPE_CHECKING: from transformers import PretrainedConfig @@ -255,6 +256,9 @@ class VllmConfig: performance. -02 is used by defult. 
See OptimizationLevel for full description.""" + weight_transfer_config: WeightTransferConfig | None = None + """The configurations for weight transfer during RL training.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/config/weight_transfer.py b/vllm/config/weight_transfer.py new file mode 100644 index 000000000000..7ccac13fbfaa --- /dev/null +++ b/vllm/config/weight_transfer.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Literal + +from vllm.config.utils import config + + +@config +@dataclass +class WeightTransferConfig: + """Configuration for weight transfer during RL training.""" + + backend: Literal["nccl"] = "nccl" + """The backend to use for weight transfer.""" diff --git a/vllm/distributed/weight_transfer/__init__.py b/vllm/distributed/weight_transfer/__init__.py new file mode 100644 index 000000000000..c96ad0e3bb4f --- /dev/null +++ b/vllm/distributed/weight_transfer/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Weight transfer engines for syncing model weights from trainers +to inference workers. 
+""" + +from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory + +__all__ = [ + "WeightTransferEngineFactory", +] diff --git a/vllm/distributed/weight_transfer/base.py b/vllm/distributed/weight_transfer/base.py new file mode 100644 index 000000000000..b87f190fcf7a --- /dev/null +++ b/vllm/distributed/weight_transfer/base.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Base class for weight transfer engines.""" + +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import KW_ONLY, dataclass, field +from typing import Any, Generic, TypeVar + +import torch + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig + +TInitInfo = TypeVar("TInitInfo", bound="WeightTransferInitInfo") +TUpdateInfo = TypeVar("TUpdateInfo", bound="WeightTransferUpdateInfo") + + +# Base protocols for backend-specific dataclasses +@dataclass +class WeightTransferInitInfo(ABC): # noqa: B024 + """Base class for backend-specific initialization info.""" + + pass + + +@dataclass +class WeightTransferUpdateInfo(ABC): # noqa: B024 + """Base class for backend-specific weight update info.""" + + _: KW_ONLY + is_checkpoint_format: bool = True + """Set to True if weights are in checkpoint/original model format and need + layerwise processing. 
Set to False if weights have already been processed + into kernel format (repacking, renaming, etc.).""" + + +# API-level request classes (accept dicts for backend-agnostic serialization) +@dataclass +class WeightTransferInitRequest: + """API-level weight transfer initialization request.""" + + init_info: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class WeightTransferUpdateRequest: + """API-level weight update request.""" + + update_info: dict[str, Any] = field(default_factory=dict) + + +class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]): + """ + Base class for weight transfer engines that handle transport of model weights + from a trainer to inference workers. + + This abstraction separates weight transfer transport logic from the worker + implementation, allowing different backends (NCCL, CUDA IPC[TODO], RDMA[TODO]) to be + plugged in. + + Subclasses should define: + init_info_cls: Type of backend-specific initialization info + update_info_cls: Type of backend-specific update info + """ + + # Subclasses should override these class attributes + init_info_cls: type[TInitInfo] + update_info_cls: type[TUpdateInfo] + + def __init__( + self, config: WeightTransferConfig, parallel_config: ParallelConfig + ) -> None: + """ + Initialize the weight transfer engine. + + Args: + config: The configuration for the weight transfer engine + parallel_config: The configuration for the parallel setup + """ + self.config = config + self.parallel_config = parallel_config + + def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo: + """ + Construct typed init info from dict with validation. 
+ + Args: + init_dict: Dictionary containing backend-specific initialization parameters + + Returns: + Typed backend-specific init info dataclass + + Raises: + ValueError: If init_dict is invalid for this backend + """ + try: + return self.init_info_cls(**init_dict) + except TypeError as e: + raise ValueError( + f"Invalid init_info for {self.__class__.__name__}: {e}" + ) from e + + def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo: + """ + Construct typed update info from dict with validation. + + Args: + update_dict: Dictionary containing backend-specific update parameters + + Returns: + Typed backend-specific update info dataclass + + Raises: + ValueError: If update_dict is invalid for this backend + """ + try: + return self.update_info_cls(**update_dict) + except TypeError as e: + raise ValueError( + f"Invalid update_info for {self.__class__.__name__}: {e}" + ) from e + + @abstractmethod + def init_transfer_engine(self, init_info: TInitInfo) -> None: + """ + Initialize the weight transfer mechanism. + This is called once at the beginning of training. + + Args: + init_info: Backend-specific initialization info + """ + raise NotImplementedError + + @abstractmethod + def receive_weights( + self, + update_info: TUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """ + Receive weights from the trainer and load them incrementally. + + Args: + update_info: Backend-specific update info containing parameter metadata + and any backend-specific data + load_weights: Callable that loads weights into the model. Called + incrementally for each weight to avoid OOM. + """ + raise NotImplementedError + + @abstractmethod + def shutdown(self) -> None: + """ + Shutdown the weight transfer engine. + This should be called when the worker is shutting down. 
+ """ + raise NotImplementedError diff --git a/vllm/distributed/weight_transfer/factory.py b/vllm/distributed/weight_transfer/factory.py new file mode 100644 index 000000000000..7235e30d1af6 --- /dev/null +++ b/vllm/distributed/weight_transfer/factory.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Factory for weight transfer engines with lazy loading.""" + +import importlib +from collections.abc import Callable +from typing import TYPE_CHECKING + +from vllm.distributed.weight_transfer.base import WeightTransferEngine +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.config.parallel import ParallelConfig + from vllm.config.weight_transfer import WeightTransferConfig + +logger = init_logger(__name__) + + +class WeightTransferEngineFactory: + """Factory for creating weight transfer engines with lazy loading. + + This factory implements a registry pattern that supports: + - Lazy loading: Engine modules are only imported when actually needed + - Extensibility: Custom engines can be registered at runtime + - Centralized registration: All built-in engines registered in one place + """ + + _registry: dict[str, Callable[[], type[WeightTransferEngine]]] = {} + + @classmethod + def register_engine( + cls, + name: str, + module_path_or_cls: str | type[WeightTransferEngine], + class_name: str | None = None, + ) -> None: + """Register an engine with lazy-loading or direct class reference. + + Supports two calling conventions: + 1. Lazy loading: register_engine(name, module_path, class_name) + 2. 
Direct class: register_engine(name, engine_cls) + + Args: + name: The name to register the engine under (e.g., "nccl") + module_path_or_cls: Either a module path string for lazy loading, + or the engine class directly + class_name: Name of the engine class (required if module_path is string) + + Raises: + ValueError: If an engine with the same name is already registered + """ + if name in cls._registry: + raise ValueError(f"Weight transfer engine '{name}' is already registered.") + + if isinstance(module_path_or_cls, str): + # Lazy loading path + module_path = module_path_or_cls + if class_name is None: + raise ValueError( + "class_name is required when registering with module path" + ) + + def loader() -> type[WeightTransferEngine]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + else: + # Direct class registration + engine_cls = module_path_or_cls + cls._registry[name] = lambda: engine_cls + + @classmethod + def create_engine( + cls, + config: "WeightTransferConfig", + parallel_config: "ParallelConfig", + ) -> WeightTransferEngine: + """Create a weight transfer engine instance. + + Args: + config: Weight transfer configuration containing the backend name + parallel_config: Parallel configuration for the engine + + Returns: + An initialized weight transfer engine instance + + Raises: + ValueError: If the backend is not registered + """ + backend = config.backend + if backend not in cls._registry: + available = list(cls._registry.keys()) + raise ValueError( + f"Invalid weight transfer backend: {backend}. " + f"Available engines: {available}" + ) + engine_cls = cls._registry[backend]() + + logger.info( + "Creating weight transfer engine: %s", + engine_cls.__name__, + ) + + return engine_cls(config, parallel_config) + + +# Register built-in weight transfer engines here. +# Registration should be centralized to ensure lazy loading - +# engine modules are only imported when actually used. 
+ +WeightTransferEngineFactory.register_engine( + "nccl", + "vllm.distributed.weight_transfer.nccl_engine", + "NCCLWeightTransferEngine", +) diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py new file mode 100644 index 000000000000..5c90198bf616 --- /dev/null +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""NCCL-based weight transfer engine.""" + +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import torch + +if TYPE_CHECKING: + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferUpdateInfo, +) +from vllm.distributed.weight_transfer.packed_tensor import ( + DEFAULT_PACKED_BUFFER_SIZE_BYTES, + DEFAULT_PACKED_NUM_BUFFERS, + packed_broadcast_consumer, +) + + +@dataclass +class NCCLWeightTransferInitInfo(WeightTransferInitInfo): + """Initialization info for NCCL weight transfer backend.""" + + master_address: str + master_port: int + rank_offset: int + world_size: int + + +@dataclass +class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo): + """Update info for NCCL weight transfer backend.""" + + names: list[str] + dtype_names: list[str] + shapes: list[list[int]] + packed: bool = False + """Whether to use packed tensor broadcasting for efficiency. + When True, multiple tensors are batched together before broadcasting + to reduce NCCL communication overhead.""" + packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES + """Size in bytes for each packed tensor buffer. Default is 1GB. 
+ Both producer and consumer must use the same value.""" + packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS + """Number of buffers for double/triple buffering during packed transfer. + Both producer and consumer must use the same value.""" + + def __post_init__(self): + """Validate that all lists have the same length.""" + num_params = len(self.names) + if len(self.dtype_names) != num_params: + raise ValueError( + f"`dtype_names` should be of the same size as `names`: " + f"got {len(self.dtype_names)} and {len(self.names)}" + ) + if len(self.shapes) != num_params: + raise ValueError( + f"`shapes` should be of the same size as `names`: " + f"got {len(self.shapes)} and {len(self.names)}" + ) + + +class NCCLWeightTransferEngine( + WeightTransferEngine[NCCLWeightTransferInitInfo, NCCLWeightTransferUpdateInfo] +): + """ + Weight transfer engine using NCCL for communication between trainer and workers. + + This implementation uses NCCL broadcast operations to transfer weights from + the trainer (rank 0) to all inference workers in a process group. + """ + + # Define backend-specific dataclass types + init_info_cls = NCCLWeightTransferInitInfo + update_info_cls = NCCLWeightTransferUpdateInfo + + def __init__( + self, config: WeightTransferConfig, parallel_config: ParallelConfig + ) -> None: + """ + Initialize the NCCL weight transfer engine. + + Args: + config: The configuration for the weight transfer engine + parallel_config: The configuration for the parallel setup + """ + super().__init__(config, parallel_config) + self.model_update_group: PyNcclCommunicator | None = None + + def init_transfer_engine(self, init_info: NCCLWeightTransferInitInfo) -> None: + """ + Initialize NCCL process group with the trainer. 
+ + Args: + init_info: NCCL initialization info containing master address, port, + rank offset, and world size + """ + + # Calculate the global rank in the trainer-worker process group + # Must account for data parallel to get unique ranks across all workers + dp_rank = self.parallel_config.data_parallel_rank + world_size_per_dp = self.parallel_config.world_size # TP * PP + rank_within_dp = self.parallel_config.rank + + # Unique rank across all DP groups + worker_rank = dp_rank * world_size_per_dp + rank_within_dp + rank = worker_rank + init_info.rank_offset + # Create stateless process group + self.model_update_group = ( + NCCLWeightTransferEngine._stateless_init_process_group( + init_info.master_address, + init_info.master_port, + rank, + init_info.world_size, + torch.cuda.current_device(), + ) + ) + + def receive_weights( + self, + update_info: NCCLWeightTransferUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """ + Receive weights from trainer via NCCL broadcast and load them incrementally. + + If update_info.packed is True, uses packed tensor broadcasting for + efficient transfer of multiple weights in batches. Otherwise, uses simple + one-by-one broadcasting. + + Args: + update_info: NCCL update info containing parameter names, dtypes, shapes, + and packed flag + load_weights: Callable that loads weights into the model. Called + incrementally for each batch of weights to avoid OOM. + """ + if self.model_update_group is None: + raise RuntimeError( + "NCCL weight transfer not initialized. " + "Call init_transfer_engine() first." 
+ ) + + if update_info.packed: + # Build iterator of (name, (shape, dtype)) from update_info + def state_dict_info_iterator(): + for name, dtype_name, shape in zip( + update_info.names, update_info.dtype_names, update_info.shapes + ): + dtype = getattr(torch, dtype_name) + yield (name, (shape, dtype)) + + packed_broadcast_consumer( + iterator=state_dict_info_iterator(), + group=self.model_update_group, + src=0, + post_unpack_func=load_weights, + buffer_size_bytes=update_info.packed_buffer_size_bytes, + num_buffers=update_info.packed_num_buffers, + ) + else: + # Use simple one-by-one broadcasting + for name, dtype_name, shape in zip( + update_info.names, update_info.dtype_names, update_info.shapes + ): + dtype = getattr(torch, dtype_name) + weight = torch.empty(shape, dtype=dtype, device="cuda") + self.model_update_group.broadcast( + weight, src=0, stream=torch.cuda.current_stream() + ) + load_weights([(name, weight)]) + del weight + + def shutdown(self) -> None: + if self.model_update_group is not None: + # Clean up the communicator by removing the reference + self.model_update_group = None + + @staticmethod + def trainer_send_weights( + iterator: Iterator[tuple[str, torch.Tensor]], + group: Any, + src: int = 0, + post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] + | None = None, + packed: bool = False, + stream: torch.cuda.Stream | None = None, + packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES, + packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS, + ) -> None: + """Broadcast weights from trainer to vLLM workers. + + Args: + iterator: Iterator of model parameters. Returns (name, tensor) tuples + group: Process group (PyNcclCommunicator) + src: Source rank (default 0, trainer is typically rank 0) + post_iter_func: Optional function to apply to each (name, tensor) pair + before broadcasting. If None, extracts just the tensor. + packed: Whether to use packed tensor broadcasting for efficiency. 
+ When True, multiple tensors are batched together before + broadcasting to reduce NCCL communication overhead. + stream: CUDA stream to use for broadcasting if packed is False. + If packed is True, new streams will be created for each buffer. + packed_buffer_size_bytes: Size in bytes for each packed tensor buffer. + Must match the value used in NCCLWeightTransferUpdateInfo. + packed_num_buffers: Number of buffers for double/triple buffering. + Must match the value used in NCCLWeightTransferUpdateInfo. + + Example: + >>> from vllm.distributed.weight_transfer.nccl_engine import ( + ... NCCLWeightTransferEngine, + ... ) + >>> param_iter = ((n, p) for n, p in model.named_parameters()) + >>> NCCLWeightTransferEngine.trainer_send_weights( + ... param_iter, group, packed=True + ... ) + """ + if post_iter_func is None: + # Default: extract just the tensor from (name, tensor) tuple + post_iter_func = lambda x: x[1] + + if packed: + # Use packed tensor broadcasting for efficiency + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_producer, + ) + + packed_broadcast_producer( + iterator=iterator, + group=group, + src=src, + post_iter_func=post_iter_func, + buffer_size_bytes=packed_buffer_size_bytes, + num_buffers=packed_num_buffers, + ) + else: + # Use simple one-by-one broadcasting + for item in iterator: + tensor = post_iter_func(item) + group.broadcast( + tensor, src=src, stream=stream or torch.cuda.current_stream() + ) + + @staticmethod + def trainer_init( + init_info: NCCLWeightTransferInitInfo | dict, + ) -> "PyNcclCommunicator": + """ + Initialize NCCL process group for trainer-side weight transfer. + + The trainer is always rank 0 in the process group. Uses the current + CUDA device (torch.cuda.current_device()). + + Args: + init_info: Either an NCCLWeightTransferInitInfo object or a dict with keys: + - master_address: str + - master_port: int + - world_size: int + + Returns: + PyNcclCommunicator for weight transfer. 
+ + Example: + >>> from vllm.distributed.weight_transfer.nccl_engine import ( + ... NCCLWeightTransferEngine, + ... ) + >>> group = NCCLWeightTransferEngine.trainer_init( + ... dict( + ... master_address=master_address, + ... master_port=master_port, + ... world_size=world_size, + ... ), + ... ) + """ + if isinstance(init_info, dict): + master_address = init_info["master_address"] + master_port = init_info["master_port"] + world_size = init_info["world_size"] + else: + # NCCLWeightTransferInitInfo object + master_address = init_info.master_address + master_port = init_info.master_port + world_size = init_info.world_size + + # Trainer is always rank 0 + return NCCLWeightTransferEngine._stateless_init_process_group( + master_address, master_port, 0, world_size, torch.cuda.current_device() + ) + + @staticmethod + def _stateless_init_process_group( + master_address, master_port, rank, world_size, device + ): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. 
+ """ + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + pg = StatelessProcessGroup.create( + host=master_address, port=master_port, rank=rank, world_size=world_size + ) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl diff --git a/vllm/distributed/weight_transfer/packed_tensor.py b/vllm/distributed/weight_transfer/packed_tensor.py new file mode 100644 index 000000000000..1c96d72edac7 --- /dev/null +++ b/vllm/distributed/weight_transfer/packed_tensor.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Packed tensor utilities for efficient weight transfer.""" + +import math +from collections.abc import Callable, Iterator +from typing import Any + +import torch + +# Default values for packed tensor configuration. +# These are imported by NCCLWeightTransferUpdateInfo and trainer_send_weights. +DEFAULT_PACKED_BUFFER_SIZE_BYTES = 1024 * 1024 * 1024 # 1GB +DEFAULT_PACKED_NUM_BUFFERS = 2 + + +def packed_broadcast_producer( + iterator: Iterator[tuple[str, torch.Tensor]], + group: Any, + src: int, + post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor], + buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES, + num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS, +) -> None: + """Broadcast tensors in a packed manner from trainer to workers. + + Args: + iterator: Iterator of model parameters. Returns a tuple of (name, tensor) + group: Process group (PyNcclCommunicator) + src: Source rank (0 in current implementation) + post_iter_func: Function to apply to each (name, tensor) pair before + packing, should return a tensor + buffer_size_bytes: Size in bytes for each packed tensor buffer. + Both producer and consumer must use the same value. + num_buffers: Number of buffers for double/triple buffering. + Both producer and consumer must use the same value. 
+ + """ + target_packed_tensor_size = buffer_size_bytes + + streams = [torch.cuda.Stream() for _ in range(num_buffers)] + buffer_idx = 0 + + packing_tensor_list: list[list[torch.Tensor]] = [[] for _ in range(num_buffers)] + packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)] + packed_tensors: list[torch.Tensor] = [ + torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers) + ] + + while True: + # Synchronize the current stream + streams[buffer_idx].synchronize() + # Start tasks for the new buffer in a new stream + with torch.cuda.stream(streams[buffer_idx]): + try: + # Initialize the packing tensor list and sizes + packing_tensor_list[buffer_idx] = [] + packing_tensor_sizes[buffer_idx] = 0 + # Pack the tensors + while True: + # Apply post processing and convert to linearized uint8 tensor + tensor = ( + post_iter_func(next(iterator)) + .contiguous() + .view(torch.uint8) + .view(-1) + ) + packing_tensor_list[buffer_idx].append(tensor) + packing_tensor_sizes[buffer_idx] += tensor.numel() + if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size: + break + # Pack the tensors and call broadcast collective + packed_tensors[buffer_idx] = torch.cat( + packing_tensor_list[buffer_idx], dim=0 + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + # Move to the next buffer + buffer_idx = (buffer_idx + 1) % num_buffers + except StopIteration: + # Do the last broadcast if there are remaining tensors + if len(packing_tensor_list[buffer_idx]) > 0: + packed_tensors[buffer_idx] = torch.cat( + packing_tensor_list[buffer_idx], dim=0 + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + break + + +def packed_broadcast_consumer( + iterator: Iterator[tuple[str, tuple[list[int], torch.dtype]]], + group: Any, + src: int, + post_unpack_func: Callable[[list[tuple[str, torch.Tensor]]], None], + buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES, + num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS, +) -> None: + """Consume packed 
tensors and unpack them into a list of tensors. + + Args: + iterator: Iterator of parameter metadata. Returns (name, (shape, dtype)) + group: Process group (PyNcclCommunicator) + src: Source rank (0 in current implementation) + post_unpack_func: Function to apply to each list of (name, tensor) after + unpacking + buffer_size_bytes: Size in bytes for each packed tensor buffer. + Both producer and consumer must use the same value. + num_buffers: Number of buffers for double/triple buffering. + Both producer and consumer must use the same value. + + """ + + def unpack_tensor( + packed_tensor: torch.Tensor, + names: list[str], + shapes: list[list[int]], + dtypes: list[torch.dtype], + tensor_sizes: list[int], + ) -> list[tuple[str, torch.Tensor]]: + """Unpack a single tensor into a list of tensors. + + Args: + packed_tensor: The packed torch.uint8 tensor to unpack + names: List of tensor names + shapes: List of tensor shapes + dtypes: List of tensor dtypes + tensor_sizes: List of tensor sizes in bytes + + Returns: + unpacked List[(name, tensor)] + """ + unpacked_tensors = packed_tensor.split(tensor_sizes) + + unpacked_list = [ + (name, tensor.contiguous().view(dtype).view(*shape)) + for name, shape, dtype, tensor in zip( + names, shapes, dtypes, unpacked_tensors + ) + ] + + return unpacked_list + + target_packed_tensor_size = buffer_size_bytes + + streams = [torch.cuda.Stream() for _ in range(num_buffers)] + buffer_idx = 0 + + packing_tensor_meta_data: list[list[tuple[str, list[int], torch.dtype, int]]] = [ + [] for _ in range(num_buffers) + ] + packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)] + packed_tensors: list[torch.Tensor] = [ + torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers) + ] + + while True: + # Synchronize the current stream + streams[buffer_idx].synchronize() + with torch.cuda.stream(streams[buffer_idx]): + # Initialize the packing tensor meta data + packing_tensor_meta_data[buffer_idx] = [] + 
packing_tensor_sizes[buffer_idx] = 0 + try: + # Form a packed tensor + while True: + name, (shape, dtype) = next(iterator) + tensor_size = math.prod(shape) * dtype.itemsize + packing_tensor_meta_data[buffer_idx].append( + (name, shape, dtype, tensor_size) + ) + packing_tensor_sizes[buffer_idx] += tensor_size + if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size: + break + # Create a packed tensor and broadcast it + packed_tensors[buffer_idx] = torch.empty( + packing_tensor_sizes[buffer_idx], dtype=torch.uint8, device="cuda" + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + # Load the packed tensor into the model + names, shapes, dtypes, tensor_sizes = zip( + *packing_tensor_meta_data[buffer_idx] + ) + post_unpack_func( + unpack_tensor( + packed_tensors[buffer_idx], + list(names), + list(shapes), + list(dtypes), + list(tensor_sizes), + ) + ) + # Move to the next buffer + buffer_idx = (buffer_idx + 1) % num_buffers + except StopIteration: + # Do the last broadcast if there are remaining tensors + if len(packing_tensor_meta_data[buffer_idx]) > 0: + # Create a packed tensor and broadcast it + packed_tensors[buffer_idx] = torch.empty( + packing_tensor_sizes[buffer_idx], + dtype=torch.uint8, + device="cuda", + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + # Load the packed tensor into the model + names, shapes, dtypes, tensor_sizes = zip( + *packing_tensor_meta_data[buffer_idx] + ) + post_unpack_func( + unpack_tensor( + packed_tensors[buffer_idx], + list(names), + list(shapes), + list(dtypes), + list(tensor_sizes), + ) + ) + break diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f3e7729f64e3..471516e32370 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -54,6 +54,7 @@ SpeculativeConfig, StructuredOutputsConfig, VllmConfig, + WeightTransferConfig, get_attr_docs, ) from vllm.config.cache import ( @@ -581,6 +582,11 @@ class EngineArgs: kv_offloading_backend: KVOffloadingBackend = 
CacheConfig.kv_offloading_backend tokens_only: bool = False + weight_transfer_config: WeightTransferConfig | None = None + """Configuration for weight transfer during RL training. + Accepts a JSON string or dict with backend-specific options. + Example: '{"backend": "nccl"}'""" + def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a @@ -591,6 +597,10 @@ def __post_init__(self): self.attention_config = AttentionConfig(**self.attention_config) if isinstance(self.eplb_config, dict): self.eplb_config = EPLBConfig(**self.eplb_config) + if isinstance(self.weight_transfer_config, dict): + self.weight_transfer_config = WeightTransferConfig( + **self.weight_transfer_config + ) # Setup plugins from vllm.plugins import load_general_plugins @@ -1189,6 +1199,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: vllm_group.add_argument( "--optimization-level", **vllm_kwargs["optimization_level"] ) + vllm_group.add_argument( + "--weight-transfer-config", **vllm_kwargs["weight_transfer_config"] + ) # Other arguments parser.add_argument( @@ -1765,6 +1778,7 @@ def create_engine_config( profiler_config=self.profiler_config, additional_config=self.additional_config, optimization_level=self.optimization_level, + weight_transfer_config=self.weight_transfer_config, ) return config diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 1502bbff4bfc..253cfc42d670 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -6,6 +6,10 @@ from typing import Any from vllm.config import ModelConfig, VllmConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightTransferUpdateRequest, +) from vllm.inputs.data import PromptType, StreamingInput from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, RequestOutput @@ -191,3 +195,13 @@ async def collective_rpc( async def get_supported_tasks(self) -> tuple[SupportedTask, 
...]: """Get supported tasks""" raise NotImplementedError + + async def init_weight_transfer_engine( + self, init_request: WeightTransferInitRequest + ) -> None: + """Initialize weight transfer for RL training.""" + raise NotImplementedError + + async def update_weights(self, request: WeightTransferUpdateRequest) -> None: + """Batched weight update for RL training.""" + raise NotImplementedError diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 24545de19cba..82078cefa062 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -34,6 +34,10 @@ RunnerOption, TokenizerMode, ) +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightTransferUpdateRequest, +) from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, @@ -359,6 +363,23 @@ def _make_config(value: Any, cls: type[_R]) -> _R: def get_tokenizer(self) -> TokenizerLike: return self.llm_engine.get_tokenizer() + def get_world_size(self, include_dp: bool = True) -> int: + """Get the world size from the parallel config. + + Args: + include_dp: If True (default), returns the world size including + data parallelism (TP * PP * DP). If False, returns the world + size without data parallelism (TP * PP). + + Returns: + The world size (tensor_parallel_size * pipeline_parallel_size), + optionally multiplied by data_parallel_size if include_dp is True. + """ + parallel_config = self.llm_engine.vllm_config.parallel_config + if include_dp: + return parallel_config.world_size_across_dp + return parallel_config.world_size + def reset_mm_cache(self) -> None: self.input_processor.clear_mm_cache() self.llm_engine.reset_mm_cache() @@ -1805,6 +1826,38 @@ def _run_engine( # its previous requests. return sorted(outputs, key=lambda x: int(x.request_id)) + def init_weight_transfer_engine( + self, request: WeightTransferInitRequest | dict + ) -> None: + """ + Initialize weight transfer for RL training. 
+ + Args: + request: Weight transfer initialization request with backend-specific info + """ + init_info_dict = ( + request["init_info"] if isinstance(request, dict) else request.init_info + ) + + self.llm_engine.collective_rpc( + "init_weight_transfer_engine", kwargs={"init_info": init_info_dict} + ) + + def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None: + """ + Update the weights of the model. + + Args: + request: Weight update request with backend-specific update info + """ + update_info_dict = ( + request["update_info"] if isinstance(request, dict) else request.update_info + ) + + self.llm_engine.collective_rpc( + "update_weights", kwargs={"update_info": update_info_dict} + ) + def __repr__(self) -> str: """Return a transformers-style hierarchical view of the model.""" # Cache the result to avoid repeated collective_rpc calls diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py index 3b37840ae089..38461b147781 100644 --- a/vllm/entrypoints/serve/rlhf/api_router.py +++ b/vllm/entrypoints/serve/rlhf/api_router.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import json from http import HTTPStatus -from fastapi import APIRouter, FastAPI, Query, Request +from fastapi import APIRouter, FastAPI, HTTPException, Query, Request from fastapi.responses import JSONResponse +import vllm.envs as envs +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightTransferUpdateRequest, +) from vllm.engine.protocol import EngineClient from vllm.logger import init_logger @@ -98,5 +103,63 @@ async def is_paused(raw_request: Request) -> JSONResponse: return JSONResponse(content={"is_paused": paused}) +@router.post("/init_weight_transfer_engine") +async def init_weight_transfer_engine(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise 
HTTPException(status_code=400, detail="Invalid JSON format") from e
+    init_info = body.get("init_info")
+    if init_info is None:
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST.value,
+            detail="Missing 'init_info' in request body",
+        )
+    await engine_client(raw_request).init_weight_transfer_engine(
+        WeightTransferInitRequest(init_info=init_info)
+    )
+    return JSONResponse(content={"message": "Weight transfer initialized"})
+
+
+@router.post("/update_weights")
+async def update_weights(raw_request: Request):
+    try:
+        body = await raw_request.json()
+    except json.JSONDecodeError as e:
+        raise HTTPException(status_code=400, detail="Invalid JSON format") from e
+ """ + parallel_config = engine_client(raw_request).vllm_config.parallel_config + if include_dp: + world_size = parallel_config.world_size_across_dp + else: + world_size = parallel_config.world_size + return JSONResponse(content={"world_size": world_size}) + + def attach_router(app: FastAPI): + if not envs.VLLM_SERVER_DEV_MODE: + return app.include_router(router) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 53bdb972bcdb..8b9fe0f3e932 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -649,6 +649,9 @@ def process_weights_after_loading(self, layer: Module) -> None: ) # Activations not quantized for marlin. + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + class Fp8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. @@ -908,6 +911,9 @@ def process_weights_after_loading(self, layer: Module) -> None: layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, @@ -1241,6 +1247,9 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_input_scale, ) + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py index f7aaf8a677c4..21795e63995e 100644 --- a/vllm/model_executor/model_loader/reload/layerwise.py +++ b/vllm/model_executor/model_loader/reload/layerwise.py @@ -216,6 +216,11 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): # 
Materialize layer tensors onto device materialize_layer(layer) + # Reset FP8 online quantization flag so process_weights_after_loading + # will run again during reload + if hasattr(layer, "_already_called_process_weights_after_loading"): + delattr(layer, "_already_called_process_weights_after_loading") + # Unwrap layerwise loading wrappers for param in get_layer_tensors(layer).values(): param.weight_loader = _get_original_loader(param) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index f1a3e341fd99..26935321438a 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -14,6 +14,10 @@ import vllm.envs as envs from vllm import TokensPrompt from vllm.config import VllmConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightTransferUpdateRequest, +) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient from vllm.inputs import PromptType @@ -1006,3 +1010,44 @@ def errored(self) -> bool: @property def dead_error(self) -> BaseException: return EngineDeadError() + + async def init_weight_transfer_engine( + self, request: WeightTransferInitRequest + ) -> None: + """ + Initialize weight transfer for RL training. + + Args: + request: Weight transfer initialization request with backend-specific info + """ + from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + ) + + if isinstance(request, WeightTransferInitRequest): + init_info_dict = request.init_info + else: + raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}") + + await self.collective_rpc( + "init_weight_transfer_engine", kwargs={"init_info": init_info_dict} + ) + + async def update_weights(self, request: WeightTransferUpdateRequest) -> None: + """ + Batched weight update for RL training. 
+ + Args: + request: Weight update request with backend-specific update info + """ + + if isinstance(request, WeightTransferUpdateRequest): + update_info_dict = request.update_info + else: + raise TypeError( + f"Expected WeightTransferUpdateRequest, got {type(request)}" + ) + + await self.collective_rpc( + "update_weights", kwargs={"update_info": update_info_dict} + ) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index b451db3826f0..09880f79bf14 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -33,6 +33,7 @@ get_pp_group, get_tp_group, ) +from vllm.distributed.weight_transfer import WeightTransferEngineFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.models.interfaces import is_mixture_of_experts @@ -89,6 +90,16 @@ def __init__( # Buffers saved before sleep self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + # Weight transfer engine (initialized on-demand) + self.weight_transfer_engine = ( + WeightTransferEngineFactory.create_engine( + self.vllm_config.weight_transfer_config, + self.vllm_config.parallel_config, + ) + if self.vllm_config.weight_transfer_config is not None + else None + ) + # Torch/CUDA profiler. Enabled and configured through profiler_config. self.profiler: Any | None = None profiler_config = vllm_config.profiler_config @@ -932,6 +943,69 @@ def save_tensorized_model( tensorizer_config=tensorizer_config, ) + def init_weight_transfer_engine(self, init_info: dict) -> None: + """ + Initialize weight transfer mechanism. + For NCCL backend, this creates a process group with the trainer. + + Args: + init_info: Dictionary containing backend-specific initialization info + """ + if self.weight_transfer_engine is None: + raise RuntimeError( + "Weight transfer not configured. " + "Please set weight_transfer_config to enable weight transfer." 
+ ) + # Parse dict into backend-specific typed dataclass + typed_init_info = self.weight_transfer_engine.parse_init_info(init_info) + self.weight_transfer_engine.init_transfer_engine(typed_init_info) + + def update_weights(self, update_info: dict) -> None: + """ + Batched weight update from the trainer. + + Args: + update_info: Dictionary containing backend-specific update info + """ + if self.weight_transfer_engine is None: + raise RuntimeError( + "Weight transfer not configured. " + "Please set weight_transfer_config to enable weight transfer." + ) + + # Parse dict into backend-specific typed dataclass + typed_update_info = self.weight_transfer_engine.parse_update_info(update_info) + + model = self.model_runner.model + + if typed_update_info.is_checkpoint_format: + from vllm.model_executor.model_loader.reload import ( + finalize_layerwise_reload, + initialize_layerwise_reload, + ) + + # Use layerwise reload pattern for checkpoint format weights + with torch.device(self.device): + initialize_layerwise_reload(model) + self.weight_transfer_engine.receive_weights( + typed_update_info, + load_weights=model.load_weights, + ) + finalize_layerwise_reload(model, self.model_config) + else: + # Weights are already in kernel format, copy directly + def load_weights_direct( + weights: list[tuple[str, torch.Tensor]], + ) -> None: + for name, weight in weights: + param = model.get_parameter(name) + param.copy_(weight) + + self.weight_transfer_engine.receive_weights( + typed_update_info, + load_weights=load_weights_direct, + ) + def shutdown(self) -> None: # has_kv_transfer_group can be None during interpreter shutdown. if ensure_kv_transfer_shutdown is not None: @@ -939,6 +1013,9 @@ def shutdown(self) -> None: if self.profiler is not None: self.profiler.shutdown() + if weight_transfer_engine := getattr(self, "weight_transfer_engine", None): + weight_transfer_engine.shutdown() + def init_worker_distributed_environment( vllm_config: VllmConfig,