From 21ef3d4ac914f73b8249bf769b400eaa6e37330a Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Wed, 7 Jan 2026 16:35:53 -0800 Subject: [PATCH 01/36] Weight transfer feature: incremental loading, async APIs, IPC/NCCL engines, DP fixes Squashed from 11 commits including: - weight transfer init - ipc support - dataclasses for weight transfer config - async + new weight update APIs - incremental weight loading, dp fixes, http example Co-authored-by: SumanthRH --- .gitignore | 4 +- .../new_weight_syncing/rlhf.py | 191 ++++++++++++++ .../new_weight_syncing/rlhf_async_new_apis.py | 248 ++++++++++++++++++ .../new_weight_syncing/rlhf_http.py | 224 ++++++++++++++++ .../new_weight_syncing/rlhf_ipc.py | 191 ++++++++++++++ .../new_weight_syncing/rlhf_utils.py | 20 ++ vllm/config/__init__.py | 2 + vllm/config/vllm.py | 6 + vllm/config/weight_transfer.py | 15 ++ vllm/distributed/weight_transfer/__init__.py | 50 ++++ vllm/distributed/weight_transfer/base.py | 154 +++++++++++ .../distributed/weight_transfer/ipc_engine.py | 127 +++++++++ .../weight_transfer/nccl_engine.py | 150 +++++++++++ vllm/engine/arg_utils.py | 22 +- vllm/engine/protocol.py | 18 ++ vllm/entrypoints/llm.py | 45 ++++ vllm/entrypoints/openai/api_server.py | 44 ++++ vllm/v1/engine/async_llm.py | 47 ++++ vllm/v1/worker/gpu_worker.py | 63 +++++ 19 files changed, 1611 insertions(+), 10 deletions(-) create mode 100644 examples/offline_inference/new_weight_syncing/rlhf.py create mode 100644 examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py create mode 100644 examples/offline_inference/new_weight_syncing/rlhf_http.py create mode 100644 examples/offline_inference/new_weight_syncing/rlhf_ipc.py create mode 100644 examples/offline_inference/new_weight_syncing/rlhf_utils.py create mode 100644 vllm/config/weight_transfer.py create mode 100644 vllm/distributed/weight_transfer/__init__.py create mode 100644 vllm/distributed/weight_transfer/base.py create mode 100644 
vllm/distributed/weight_transfer/ipc_engine.py create mode 100644 vllm/distributed/weight_transfer/nccl_engine.py diff --git a/.gitignore b/.gitignore index 7cda86478664..9d09506ebd84 100644 --- a/.gitignore +++ b/.gitignore @@ -98,7 +98,7 @@ ipython_config.py **/generated/** # uv -uv.lock +# uv.lock # pyenv # For a library or package, you might want to ignore these files since the code is @@ -138,7 +138,7 @@ celerybeat.pid *.sage.py # Environments -.env +# .env .venv env/ venv/ diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py new file mode 100644 index 000000000000..6f69caae2b84 --- /dev/null +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray, +with new weight syncing APIs. + +The script separates training and inference workloads onto distinct GPUs +so that Ray can manage process placement and inter-process communication. +A Hugging Face Transformer model occupies GPU 0 for training, whereas a +tensor-parallel vLLM inference engine occupies GPU 1–2. + +The example performs the following steps: + +* Load the training model on GPU 0. +* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism + and Ray placement groups. +* Generate text from a list of prompts using the inference engine. +* Update the weights of the training model and broadcast the updated weights + to the inference engine by using a Ray collective RPC group. Note that + for demonstration purposes we simply zero out the weights. + +This example assumes a single-node cluster with three GPUs, but Ray +supports multi-node clusters. vLLM expects the GPUs are only used for vLLM +workloads. Residual GPU activity interferes with vLLM memory profiling and +causes unexpected behavior. 
+""" + +import os +from dataclasses import asdict + +import ray +import torch +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from rlhf_utils import stateless_init_process_group +from transformers import AutoModelForCausalLM + +from vllm import LLM, SamplingParams +from vllm.config import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightUpdateRequest, +) +from vllm.distributed.weight_transfer.nccl_engine import NCCLInitInfo, NCCLUpdateInfo +from vllm.utils.network_utils import get_ip, get_open_port + +MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507" + + +class MyLLM(LLM): + """Configure the vLLM worker for Ray placement group execution.""" + + def __init__(self, *args, **kwargs): + # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray + # so that vLLM can manage its own device placement within the worker. + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + super().__init__(*args, **kwargs) + + +# Load the OPT-125M model onto GPU 0 for the training workload. +train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16) +train_model.to("cuda:0") + +# Initialize Ray and set the visible devices. The vLLM engine will +# be placed on GPUs 1 and 2. +os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" +ray.init(runtime_env={"excludes": [".git/objects/pack/"]}) +# ray.init() + +# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. 
+# Learn more about Ray placement groups: +# https://docs.ray.io/en/latest/placement-groups.html +pg_training = placement_group([{"GPU": 1, "CPU": 0}]) +ray.get(pg_training.ready()) + +pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) +ray.get(pg_inference.ready()) +scheduling_inference = PlacementGroupSchedulingStrategy( + placement_group=pg_inference, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=0, +) + +# Launch the vLLM inference engine. The `enforce_eager` flag reduces +# start-up latency. +# Note: Weight transfer APIs (init_weight_transfer, update_weights, +# finalize_weight_update) are now native to vLLM workers. +llm = ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=scheduling_inference, +)(MyLLM).remote( + model=MODEL_NAME, + enforce_eager=True, + tensor_parallel_size=1, + data_parallel_size=2, + distributed_executor_backend="ray", + weight_transfer_config=WeightTransferConfig(backend="nccl"), + enable_expert_parallel=True, + all2all_backend="pplx", +) + +# Generate text from the prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +sampling_params = SamplingParams(temperature=0) + +outputs = ray.get(llm.generate.remote(prompts, sampling_params)) + +print("-" * 50) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + +# Set up the communication channel between the training process and the +# inference engine. 
+master_address = get_ip() +master_port = get_open_port() + +handle = llm.init_weight_transfer.remote( + WeightTransferInitRequest( + init_info=asdict( + NCCLInitInfo( + master_address=master_address, + master_port=master_port, + rank_offset=1, + world_size=3, + ) + ) + ) +) + +model_update_group = stateless_init_process_group( + master_address, master_port, 0, 3, torch.device("cuda:0") +) +ray.get(handle) + +# Simulate a training step by zeroing out all model weights. +# In a real RLHF training loop the weights would be updated using the gradient +# from an RL objective such as PPO on a reward model. +for name, p in train_model.named_parameters(): + p.data.zero_() + +# Synchronize the updated weights to the inference engine using batched API. +# Collect all weight metadata +names = [] +dtype_names = [] +shapes = [] +for name, p in train_model.named_parameters(): + names.append(name) + dtype_names.append(str(p.dtype).split(".")[-1]) + shapes.append(p.shape) + +# Issue update_weights call with NCCL-specific update info +handle = llm.update_weights.remote( + WeightUpdateRequest( + update_info=asdict( + NCCLUpdateInfo( + names=names, + dtype_names=dtype_names, + shapes=shapes, + ) + ) + ) +) + +# Broadcast all weights from trainer +for name, p in train_model.named_parameters(): + model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) + +ray.get(handle) + +# Finalize the weight update (processes weights for quantization/kernel format) +ray.get(llm.finalize_weight_update.remote()) + +# Generate text with the updated model. The output is expected to be nonsense +# because the weights are zero. 
+outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) +print("-" * 50) +for output in outputs_updated: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py new file mode 100644 index 000000000000..0bd38428d5b5 --- /dev/null +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Demonstrates asynchronous reinforcement learning from human feedback (RLHF) +using vLLM and Ray, with the new weight syncing APIs + +The script separates training and inference workloads onto distinct GPUs +so that Ray can manage process placement and inter-process communication. +A Hugging Face Transformer model occupies GPU 0 for training, whereas a +tensor-parallel vLLM inference engine occupies GPU 1–2. + +The example performs the following steps: + +* Load the training model on GPU 0. +* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism + and Ray placement groups. +* Start generation from a list of prompts using the inference engine. +* Pause generation once generation completes for one sequence +* Update the weights of the training model and broadcast the updated weights + to the inference engine by using a Ray collective RPC group. Note that + for demonstration purposes we simply zero out the weights. +* Resume generation and print out the results + +This example assumes a single-node cluster with three GPUs, but Ray +supports multi-node clusters. vLLM expects the GPUs are only used for vLLM +workloads. Residual GPU activity interferes with vLLM memory profiling and +causes unexpected behavior. 
+""" + +import asyncio +import os +import uuid + +import ray +import torch +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from rlhf_utils import stateless_init_process_group +from transformers import AutoModelForCausalLM + +import vllm +from vllm import SamplingParams +from vllm.config import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightUpdateRequest, +) +from vllm.utils.network_utils import get_ip, get_open_port + + +class MyLLM: + """Simple wrapper over AsyncLLM for supporting async RL.""" + + def __init__(self, **kwargs): + self.engine = vllm.AsyncLLMEngine.from_engine_args( + vllm.AsyncEngineArgs(**kwargs) + ) + self.generation_paused_event = asyncio.Event() + + async def generate( + self, prompt: str, sampling_params: vllm.SamplingParams + ) -> vllm.RequestOutput: + async for request_output in self.engine.generate( + prompt, sampling_params, request_id=str(uuid.uuid4()) + ): + final_output = request_output + return final_output + + async def generate_with_retry( + self, prompt: str, sampling_params: vllm.SamplingParams + ) -> vllm.RequestOutput: + finish_reason = "abort" + while finish_reason == "abort": + await self._wait_for_generation_to_resume() + output = await self.generate(prompt, sampling_params) + finish_reason = output.outputs[0].finish_reason + if finish_reason == "abort": + print(f"REQ ABORTED, prompt: {prompt}, text: {output.outputs[0].text}") + prompt += output.outputs[0].text + return output + + async def abort_generation(self) -> None: + self.generation_paused_event.set() + unfinished_request_ids = list( + self.engine.output_processor.request_states.keys() + ) + if unfinished_request_ids: + await self.engine.abort(unfinished_request_ids) + await self.engine.reset_prefix_cache() + print( + f"abort_generation() finished, aborted" + f"{len(unfinished_request_ids)} requests" + ) + + async def 
resume_generation(self) -> None: + self.generation_paused_event.clear() + + async def collective_rpc(self, method: str, args: tuple = ()): + return await self.engine.collective_rpc(method, args=args) + + async def _wait_for_generation_to_resume(self) -> None: + """Waits for generation to be resumed, intended for in-flight weight updates + and partial rollouts.""" + while self.generation_paused_event.is_set(): + await asyncio.sleep(0.5) + + async def init_weight_transfer(self, request: WeightTransferInitRequest) -> None: + print("reached init weight transfer") + return await self.engine.init_weight_transfer(request) + + async def update_weights(self, request: WeightUpdateRequest) -> None: + return await self.engine.update_weights(request) + + async def finalize_weight_update(self) -> None: + return await self.engine.finalize_weight_update() + + +# Load the OPT-125M model onto GPU 0 for the training workload. +train_model = AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", dtype=torch.bfloat16 +) +train_model.to("cuda:0") + +# Initialize Ray and set the visible devices. The vLLM engine will +# be placed on GPUs 1 and 2. +os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" +ray.init(runtime_env={"excludes": [".git/objects/pack/"]}) +# ray.init() + +# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. +# Learn more about Ray placement groups: +# https://docs.ray.io/en/latest/placement-groups.html +pg_training = placement_group([{"GPU": 1, "CPU": 0}]) +ray.get(pg_training.ready()) + +pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) +ray.get(pg_inference.ready()) +scheduling_inference = PlacementGroupSchedulingStrategy( + placement_group=pg_inference, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=0, +) + +# Launch the vLLM inference engine. The `enforce_eager` flag reduces +# start-up latency. 
+# Note: Weight transfer APIs (init_weight_transfer, update_weights, +# finalize_weight_update) are now native to vLLM workers. +llm = ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=scheduling_inference, +)(MyLLM).remote( + model="facebook/opt-125m", + enforce_eager=True, + tensor_parallel_size=2, + distributed_executor_backend="ray", + weight_transfer_config=WeightTransferConfig(backend="nccl"), +) + +# Generate text from the prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +sampling_params = SamplingParams(temperature=0) + +# Set up the communication channel between the training process and the +# inference engine. +master_address = get_ip() +master_port = get_open_port() + +print("reached init weight in driver") +handle = llm.init_weight_transfer.remote( + WeightTransferInitRequest( + init_info=dict( + master_address=master_address, + master_port=master_port, + rank_offset=1, + world_size=3, + ) + ) +) + +model_update_group = stateless_init_process_group( + master_address, master_port, 0, 3, torch.device("cuda:0") +) +ray.get(handle) + + +generation_futures = [ + llm.generate_with_retry.remote(prompt, sampling_params) for prompt in prompts +] + +finished, pending = ray.wait(generation_futures, num_returns=1) + +# Abort generation in preparation for weight sync +ray.get(llm.abort_generation.remote()) + +# Simulate a training step by zeroing out all model weights. +# In a real RLHF training loop the weights would be updated using the gradient +# from an RL objective such as PPO on a reward model. +for name, p in train_model.named_parameters(): + p.data.zero_() + +# Synchronize the updated weights to the inference engine using batched API. 
+# Collect all weight metadata +names = [] +dtype_names = [] +shapes = [] +for name, p in train_model.named_parameters(): + names.append(name) + dtype_names.append(str(p.dtype).split(".")[-1]) + shapes.append(p.shape) + +# Issue update_weights call +handle = llm.update_weights.remote( + WeightUpdateRequest( + update_info=dict(names=names, dtype_names=dtype_names, shapes=shapes) + ) +) + +# Broadcast all weights from trainer +for name, p in train_model.named_parameters(): + model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) + +ray.get(handle) + +# Finalize the weight update (processes weights for quantization/kernel format) +ray.get(llm.finalize_weight_update.remote()) + +# Resume generation since weight sync is complete +ray.get(llm.resume_generation.remote()) + +# Get all outputs +outputs = ray.get(finished) + ray.get(pending) + +# We expect the first output to be normal generation. +# The other outputs should have generated regular results midway +# and then have garbage tokens because we zero'd out the weights +print("-" * 50) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) diff --git a/examples/offline_inference/new_weight_syncing/rlhf_http.py b/examples/offline_inference/new_weight_syncing/rlhf_http.py new file mode 100644 index 000000000000..cd321bb5f895 --- /dev/null +++ b/examples/offline_inference/new_weight_syncing/rlhf_http.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Demonstrates reinforcement learning from human feedback (RLHF) using vLLM +via HTTP API, with new weight syncing APIs. + +Unlike rlhf.py which creates a vLLM instance programmatically, this script +assumes you have already started a vLLM server using `vllm serve`. 
It uses: +- OpenAI-compatible API for inference requests +- HTTP endpoints for weight transfer control plane +- NCCL for actual weight data transfer + +Prerequisites: + Start a vLLM server with weight transfer enabled: + + $ vllm serve facebook/opt-125m \ + --enforce-eager \ + --weight-transfer-backend nccl + + Then run this script: + + $ python rlhf_http.py + +The example performs the following steps: + +* Load the training model on GPU 0. +* Generate text using the vLLM server via OpenAI-compatible API. +* Initialize weight transfer via HTTP endpoint. +* Update the weights of the training model and broadcast the updated weights + to the vLLM server using NCCL. Note that for demonstration purposes we + simply zero out the weights. +* Generate text again to show the effect of the weight update. +""" + +from dataclasses import asdict + +import requests +import torch +from openai import OpenAI +from rlhf_utils import stateless_init_process_group +from transformers import AutoModelForCausalLM + +from vllm.distributed.weight_transfer.nccl_engine import NCCLInitInfo, NCCLUpdateInfo +from vllm.utils.network_utils import get_ip, get_open_port + +BASE_URL = "http://localhost:8000" +MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507" + + +INFERENCE_WORLD_SIZE = 2 +WORLD_SIZE = INFERENCE_WORLD_SIZE + 1 + +DEVICE = f"cuda:{INFERENCE_WORLD_SIZE}" + + +def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]: + """Generate completions using the OpenAI-compatible API.""" + results = [] + for prompt in prompts: + response = client.completions.create( + model=model, + prompt=prompt, + max_tokens=32, + temperature=0, + ) + results.append(response.choices[0].text) + return results + + +def init_weight_transfer( + base_url: str, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, +) -> None: + """Initialize weight transfer via HTTP endpoint.""" + url = f"{base_url}/init_weight_transfer" + payload = { + "init_info": asdict( + 
NCCLInitInfo( + master_address=master_address, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, + ) + ) + } + response = requests.post(url, json=payload, timeout=60) + response.raise_for_status() + + +def update_weights( + base_url: str, + names: list[str], + dtype_names: list[str], + shapes: list[list[int]], +) -> None: + """Update weights via HTTP endpoint.""" + url = f"{base_url}/update_weights" + payload = { + "update_info": asdict( + NCCLUpdateInfo( + names=names, + dtype_names=dtype_names, + shapes=shapes, + ) + ) + } + response = requests.post(url, json=payload, timeout=300) + response.raise_for_status() + + +def finalize_weight_update(base_url: str) -> None: + """Finalize weight update via HTTP endpoint.""" + url = f"{base_url}/finalize_weight_update" + response = requests.post(url, timeout=60) + response.raise_for_status() + + +def main(): + # Load the training model + print(f"Loading training model: {MODEL_NAME}") + train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16) + train_model.to(DEVICE) + + # Create OpenAI client pointing to the vLLM server + client = OpenAI( + base_url=f"{BASE_URL}/v1", + api_key="EMPTY", # vLLM doesn't require an API key by default + ) + + # Test prompts + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Generate text before weight update + print("-" * 50) + print("Generating text BEFORE weight update:") + print("-" * 50) + outputs = generate_completions(client, MODEL_NAME, prompts) + for prompt, generated_text in zip(prompts, outputs): + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + # Set up the communication channel between the training process and the + # vLLM server. The trainer is rank 0, vLLM worker(s) start at rank_offset. 
+ master_address = get_ip() + master_port = get_open_port() + rank_offset = 1 + + print(f"Initializing weight transfer: master={master_address}:{master_port}") + + # Initialize weight transfer on vLLM server (this is async, server will + # wait for NCCL connection) + import threading + + init_thread = threading.Thread( + target=init_weight_transfer, + args=(BASE_URL, master_address, master_port, rank_offset, WORLD_SIZE), + ) + init_thread.start() + + # Initialize NCCL process group on trainer side + device = torch.device(DEVICE) + model_update_group = stateless_init_process_group( + master_address, master_port, 0, WORLD_SIZE, device + ) + + # Wait for init_weight_transfer to complete + init_thread.join() + + # Simulate a training step by zeroing out all model weights. + # In a real RLHF training loop the weights would be updated using the + # gradient from an RL objective such as PPO on a reward model. + print("Simulating training step (zeroing out weights)...") + for name, p in train_model.named_parameters(): + p.data.zero_() + + # Collect weight metadata for the update request + names = [] + dtype_names = [] + shapes = [] + for name, p in train_model.named_parameters(): + names.append(name) + dtype_names.append(str(p.dtype).split(".")[-1]) + shapes.append(list(p.shape)) + + # Start the update_weights call in a separate thread since it will block + # waiting for NCCL broadcasts + update_thread = threading.Thread( + target=update_weights, + args=(BASE_URL, names, dtype_names, shapes), + ) + update_thread.start() + + # Broadcast all weights from trainer to vLLM workers + print("Broadcasting weights via NCCL...") + for name, p in train_model.named_parameters(): + model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) + + # Wait for update_weights to complete + update_thread.join() + + # Finalize the weight update (processes weights for quantization/kernel format) + finalize_weight_update(BASE_URL) + + # Generate text after weight update. 
The output is expected to be nonsense + # because the weights are zero. + print("-" * 50) + print("Generating text AFTER weight update (expect nonsense):") + print("-" * 50) + outputs_updated = generate_completions(client, MODEL_NAME, prompts) + for prompt, generated_text in zip(prompts, outputs_updated): + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/new_weight_syncing/rlhf_ipc.py b/examples/offline_inference/new_weight_syncing/rlhf_ipc.py new file mode 100644 index 000000000000..7ed0a4dbf9af --- /dev/null +++ b/examples/offline_inference/new_weight_syncing/rlhf_ipc.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray, +with new weight syncing APIs + +The script colocates the training and inference workloads onto the same GPU using Ray. + +The example performs the following steps: + +* Request a placement group of 1 GPU. +* Place the inference model on the above GPU using the placement group. +* Place and load the training model on the same GPU using the placement group. +* Generate text from a list of prompts using the inference engine. +* Update the weights of the training model and broadcast the updated weights + to the inference engine by using a Ray collective RPC group. Note that + for demonstration purposes we simply zero out the weights. + +This example assumes a single-node cluster with a single GPUs, +but can be extended to multiple GPUs. 
+""" + +import os +from dataclasses import asdict + +import ray +import torch +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from transformers import AutoModelForCausalLM + +from vllm import LLM, SamplingParams +from vllm.config import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightUpdateRequest, +) +from vllm.distributed.weight_transfer.ipc_engine import IPCUpdateInfo + + +class MyLLM(LLM): + """Configure the vLLM worker for Ray placement group execution.""" + + def __init__(self, *args, **kwargs): + # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray + # so that vLLM can manage its own device placement within the worker. + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + # Each worker uses 0.4 GPU so that two instances fit on the same GPUs. + os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" + os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0" + # needed for ipc handle serialization + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + super().__init__(*args, **kwargs) + + +def get_physical_gpu_id(): + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + return str(props.uuid) + + +# Load the OPT-125M model onto GPU 0 for the training workload. 
+ +# MODEL_NAME = "facebook/opt-125m" +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" + + +@ray.remote +class TrainModel: + def __init__(self, llm_handle: ray.ObjectRef): + self.train_model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + ) + self.train_model.to("cuda:0") + self.llm_handle = llm_handle + + def init_weight_transfer(self): + self.llm_handle.init_weight_transfer.remote(WeightTransferInitRequest()) + + def broadcast_weights(self, llm_handle: ray.ObjectRef): + self.llm_handle = llm_handle + names, dtypes, shapes, ipc_handles = [], [], [], [] + + for name, p in self.train_model.named_parameters(): + names.append(name) + dtypes.append(str(p.dtype).split(".")[-1]) + shapes.append(p.shape) + + from torch.multiprocessing.reductions import reduce_tensor + + weight = p.detach().contiguous() + ipc_handle = reduce_tensor(weight) + ipc_handle = {get_physical_gpu_id(): ipc_handle} + ipc_handles.append(ipc_handle) + + ray.get( + self.llm_handle.update_weights.remote( + WeightUpdateRequest( + update_info=asdict( + IPCUpdateInfo( + names=names, + dtype_names=dtypes, + shapes=shapes, + ipc_handles=ipc_handles, + ) + ) + ) + ) + ) + + def zero_data(self): + # Simulate a training step by zeroing out all model weights. + # In a real RLHF training loop the weights would be updated using the gradient + # from an RL objective such as PPO on a reward model. + for name, p in self.train_model.named_parameters(): + p.data.zero_() + + +ray.init(runtime_env={"excludes": [".git/objects/pack/"]}) + +# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. 
+# Learn more about Ray placement groups: +# https://docs.ray.io/en/latest/placement-groups.html + +pg_colocate = placement_group([{"GPU": 1, "CPU": 0}]) +ray.get(pg_colocate.ready()) + + +llm = ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg_colocate, + placement_group_capture_child_tasks=True, + ), +)(MyLLM).remote( + model=MODEL_NAME, + enforce_eager=True, + tensor_parallel_size=1, + distributed_executor_backend="ray", + gpu_memory_utilization=0.7, + weight_transfer_config=WeightTransferConfig(backend="ipc"), +) + +train_model = TrainModel.options( + num_gpus=0.1, + num_cpus=0, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg_colocate, placement_group_capture_child_tasks=True + ), +).remote(llm) + + +# Generate text from the prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +sampling_params = SamplingParams(temperature=0) + +outputs = ray.get(llm.generate.remote(prompts, sampling_params)) + +print("-" * 50) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + +ray.get(train_model.init_weight_transfer.remote()) + +train_model.zero_data.remote() + +# Synchronize the updated weights to the inference engine using batched API. +ray.get(train_model.broadcast_weights.remote(llm)) + +# Finalize the weight update (processes weights for quantization/kernel format) +ray.get(llm.finalize_weight_update.remote()) + +# Generate text with the updated model. The output is expected to be nonsense +# because the weights are zero. 
+outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) +print("-" * 50) +for output in outputs_updated: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) diff --git a/examples/offline_inference/new_weight_syncing/rlhf_utils.py b/examples/offline_inference/new_weight_syncing/rlhf_utils.py new file mode 100644 index 000000000000..35761ae3996d --- /dev/null +++ b/examples/offline_inference/new_weight_syncing/rlhf_utils.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +def stateless_init_process_group(master_address, master_port, rank, world_size, device): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + """ + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + pg = StatelessProcessGroup.create( + host=master_address, port=master_port, rank=rank, world_size=world_size + ) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 3c77fad41d07..16cc6206db54 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -45,6 +45,7 @@ get_layers_from_vllm_config, set_current_vllm_config, ) +from vllm.config.weight_transfer import WeightTransferConfig # __all__ should only contain classes and functions. # Types and globals should be imported from their respective modules. 
@@ -107,4 +108,5 @@ "get_current_vllm_config", "set_current_vllm_config", "get_layers_from_vllm_config", + "WeightTransferConfig", ] diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 30a24233575f..9819c9a0df5a 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -44,6 +44,7 @@ from .speculative import SpeculativeConfig from .structured_outputs import StructuredOutputsConfig from .utils import SupportsHash, config +from .weight_transfer import WeightTransferConfig if TYPE_CHECKING: from transformers import PretrainedConfig @@ -242,6 +243,11 @@ class VllmConfig: performance. -02 is used by defult. See OptimizationLevel for full description.""" + weight_transfer_config: WeightTransferConfig = Field( + default_factory=WeightTransferConfig + ) + """The configurations for weight transfer during RL training.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/config/weight_transfer.py b/vllm/config/weight_transfer.py new file mode 100644 index 000000000000..370652e2cbb5 --- /dev/null +++ b/vllm/config/weight_transfer.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Literal + +from vllm.config.utils import config + + +@config +@dataclass +class WeightTransferConfig: + """Configuration for weight transfer during RL training.""" + + backend: Literal["nccl", "ipc", "rdma"] = "nccl" + """The backend to use for weight transfer.""" diff --git a/vllm/distributed/weight_transfer/__init__.py b/vllm/distributed/weight_transfer/__init__.py new file mode 100644 index 000000000000..bd06049e3bc4 --- /dev/null +++ b/vllm/distributed/weight_transfer/__init__.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Weight transfer engines for syncing model weights from trainers +to inference workers. 
+""" + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferEngine, + WeightUpdateRequest, +) +from vllm.distributed.weight_transfer.ipc_engine import ( + IPCWeightTransferEngine, +) +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLWeightTransferEngine, +) + +WEIGHT_TRANSFER_ENGINE_REGISTRY = { + "nccl": NCCLWeightTransferEngine, + "ipc": IPCWeightTransferEngine, +} + + +def register_weight_transfer_engine( + name: str, engine: type[WeightTransferEngine] +) -> None: + if name in WEIGHT_TRANSFER_ENGINE_REGISTRY: + raise ValueError(f"Weight transfer engine {name} already registered") + WEIGHT_TRANSFER_ENGINE_REGISTRY[name] = engine + + +def init_transfer_engine(config: WeightTransferConfig, parallel_config: ParallelConfig): + if config.backend not in WEIGHT_TRANSFER_ENGINE_REGISTRY: + raise ValueError(f"Invalid weight transfer backend: {config.backend}") + + engine_cls = WEIGHT_TRANSFER_ENGINE_REGISTRY[config.backend] + return engine_cls(config, parallel_config) + + +__all__ = [ + "WeightTransferEngine", + "NCCLWeightTransferEngine", + "register_weight_transfer_engine", + "WEIGHT_TRANSFER_ENGINE_MAP", + "IPCWeightTransferEngine", + "WeightUpdateRequest", +] diff --git a/vllm/distributed/weight_transfer/base.py b/vllm/distributed/weight_transfer/base.py new file mode 100644 index 000000000000..9251405939db --- /dev/null +++ b/vllm/distributed/weight_transfer/base.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Base class for weight transfer engines.""" + +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Generic, TypeVar + +import torch + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig + 
+TInitInfo = TypeVar("TInitInfo", bound="BackendInitInfo") +TUpdateInfo = TypeVar("TUpdateInfo", bound="BackendUpdateInfo") + + +# Base protocols for backend-specific dataclasses +@dataclass +class BackendInitInfo(ABC): # noqa: B024 + """Base class for backend-specific initialization info.""" + + pass + + +@dataclass +class BackendUpdateInfo(ABC): # noqa: B024 + """Base class for backend-specific weight update info.""" + + pass + + +# API-level request classes (accept dicts for backend-agnostic serialization) +@dataclass +class WeightTransferInitRequest: + """API-level weight transfer initialization request.""" + + init_info: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class WeightUpdateRequest: + """API-level weight update request.""" + + update_info: dict[str, Any] = field(default_factory=dict) + + +class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]): + """ + Base class for weight transfer engines that handle transport of model weights + from a trainer to inference workers. + + This abstraction separates weight transfer transport logic from the worker + implementation, allowing different backends (NCCL, CUDA IPC, RDMA) to be + plugged in. + + Subclasses should define: + init_info_cls: Type of backend-specific initialization info + update_info_cls: Type of backend-specific update info + """ + + # Subclasses should override these class attributes + init_info_cls: type[TInitInfo] + update_info_cls: type[TUpdateInfo] + + def __init__( + self, config: WeightTransferConfig, parallel_config: ParallelConfig + ) -> None: + """ + Initialize the weight transfer engine. + + Args: + config: The configuration for the weight transfer engine + parallel_config: The configuration for the parallel setup + """ + self.config = config + self.parallel_config = parallel_config + + def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo: + """ + Construct typed init info from dict with validation. 
+ + Args: + init_dict: Dictionary containing backend-specific initialization parameters + + Returns: + Typed backend-specific init info dataclass + + Raises: + ValueError: If init_dict is invalid for this backend + """ + try: + return self.init_info_cls(**init_dict) + except TypeError as e: + raise ValueError( + f"Invalid init_info for {self.__class__.__name__}: {e}" + ) from e + + def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo: + """ + Construct typed update info from dict with validation. + + Args: + update_dict: Dictionary containing backend-specific update parameters + + Returns: + Typed backend-specific update info dataclass + + Raises: + ValueError: If update_dict is invalid for this backend + """ + try: + return self.update_info_cls(**update_dict) + except TypeError as e: + raise ValueError( + f"Invalid update_info for {self.__class__.__name__}: {e}" + ) from e + + @abstractmethod + def init_transfer(self, init_info: TInitInfo) -> None: + """ + Initialize the weight transfer mechanism. + This is called once at the beginning of training. + + Args: + init_info: Backend-specific initialization info + """ + raise NotImplementedError + + @abstractmethod + def receive_weights( + self, + update_info: TUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """ + Receive weights from the trainer and load them incrementally. + + Args: + update_info: Backend-specific update info containing parameter metadata + and any backend-specific data (e.g., IPC handles) + load_weights: Callable that loads weights into the model. Called + incrementally for each weight to avoid OOM. + """ + raise NotImplementedError + + @abstractmethod + def shutdown(self) -> None: + """ + Shutdown the weight transfer engine. + This should be called when the worker is shutting down. 
+ """ + raise NotImplementedError diff --git a/vllm/distributed/weight_transfer/ipc_engine.py b/vllm/distributed/weight_transfer/ipc_engine.py new file mode 100644 index 000000000000..d13b3fab4a3b --- /dev/null +++ b/vllm/distributed/weight_transfer/ipc_engine.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable +from dataclasses import dataclass + +import torch + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + BackendInitInfo, + BackendUpdateInfo, + WeightTransferEngine, +) + + +@dataclass +class IPCInitInfo(BackendInitInfo): + """Initialization info for IPC weight transfer backend. No init needed for IPC.""" + + pass + + +@dataclass +class IPCUpdateInfo(BackendUpdateInfo): + """Update info for IPC weight transfer backend.""" + + names: list[str] + dtype_names: list[str] + shapes: list[list[int]] + ipc_handles: list[dict[str, tuple[Callable, tuple]]] + + def __post_init__(self): + """Validate that all lists have the same length.""" + num_params = len(self.names) + if len(self.dtype_names) != num_params: + raise ValueError( + f"`dtype_names` should be of the same size as `names`: " + f"got {len(self.dtype_names)} and {len(self.names)}" + ) + if len(self.shapes) != num_params: + raise ValueError( + f"`shapes` should be of the same size as `names`: " + f"got {len(self.shapes)} and {len(self.names)}" + ) + if len(self.ipc_handles) != num_params: + raise ValueError( + f"`ipc_handles` should be of the same size as `names`: " + f"got {len(self.ipc_handles)} and {len(self.names)}" + ) + + +class IPCWeightTransferEngine(WeightTransferEngine[IPCInitInfo, IPCUpdateInfo]): + """ + Weight transfer engine using CUDA IPC for communication between trainer and workers. 
+
+    This implementation uses CUDA IPC to transfer weights from the trainer (rank 0)
+    to all inference workers in a process group.
+    """
+
+    # Define backend-specific dataclass types
+    init_info_cls = IPCInitInfo
+    update_info_cls = IPCUpdateInfo
+
+    def __init__(
+        self, config: WeightTransferConfig, parallel_config: ParallelConfig
+    ) -> None:
+        """
+        Initialize the weight transfer engine.
+
+        Args:
+            config: The configuration for the weight transfer engine
+            parallel_config: The configuration for the parallel setup
+        """
+        super().__init__(config, parallel_config)
+
+    def init_transfer(self, init_info: IPCInitInfo) -> None:
+        """
+        Initialize the weight transfer mechanism.
+        This is called once at the beginning of training.
+        No initialization needed for IPC backend.
+
+        Args:
+            init_info: IPC initialization info (empty)
+        """
+        pass
+
+    def receive_weights(
+        self,
+        update_info: IPCUpdateInfo,
+        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
+    ) -> None:
+        """
+        Receive weights from the trainer via CUDA IPC handles and load them.
+
+        Args:
+            update_info: IPC update info containing parameter names, dtypes, shapes,
+                and IPC handles. Each IPC handle is a mapping between physical
+                GPU UUID and the IPC handle tuple (func, args).
+            load_weights: Callable that loads weights into the model. Called
+                incrementally for each weight to avoid OOM.
+        """
+        for name, _dtype_name, _shape, ipc_handle in zip(
+            update_info.names,
+            update_info.dtype_names,
+            update_info.shapes,
+            update_info.ipc_handles,
+        ):
+            device_index = torch.cuda.current_device()
+            props = torch.cuda.get_device_properties(device_index)
+            physical_gpu_id = str(props.uuid)
+
+            handle = ipc_handle[physical_gpu_id]
+
+            func, args = handle
+            list_args = list(args)  # type: ignore
+            list_args[6] = device_index
+            weight = func(*list_args)  # type: ignore
+            # Load immediately so weights are not accumulated in memory.
+            load_weights([(name, weight)])
+            del weight
+
+    def shutdown(self) -> None:
+        """
+        Shutdown the weight transfer engine.
+ """ + pass diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py new file mode 100644 index 000000000000..d10c7c9b36c4 --- /dev/null +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""NCCL-based weight transfer engine.""" + +from collections.abc import Callable +from dataclasses import dataclass + +import torch + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + BackendInitInfo, + BackendUpdateInfo, + WeightTransferEngine, +) + + +@dataclass +class NCCLInitInfo(BackendInitInfo): + """Initialization info for NCCL weight transfer backend.""" + + master_address: str + master_port: int + rank_offset: int + world_size: int + + +@dataclass +class NCCLUpdateInfo(BackendUpdateInfo): + """Update info for NCCL weight transfer backend.""" + + names: list[str] + dtype_names: list[str] + shapes: list[list[int]] + + def __post_init__(self): + """Validate that all lists have the same length.""" + num_params = len(self.names) + if len(self.dtype_names) != num_params: + raise ValueError( + f"`dtype_names` should be of the same size as `names`: " + f"got {len(self.dtype_names)} and {len(self.names)}" + ) + if len(self.shapes) != num_params: + raise ValueError( + f"`shapes` should be of the same size as `names`: " + f"got {len(self.shapes)} and {len(self.names)}" + ) + + +class NCCLWeightTransferEngine(WeightTransferEngine[NCCLInitInfo, NCCLUpdateInfo]): + """ + Weight transfer engine using NCCL for communication between trainer and workers. + + This implementation uses NCCL broadcast operations to transfer weights from + the trainer (rank 0) to all inference workers in a process group. 
+ """ + + # Define backend-specific dataclass types + init_info_cls = NCCLInitInfo + update_info_cls = NCCLUpdateInfo + + def __init__( + self, config: WeightTransferConfig, parallel_config: ParallelConfig + ) -> None: + """ + Initialize the NCCL weight transfer engine. + + Args: + config: The configuration for the weight transfer engine + parallel_config: The configuration for the parallel setup + """ + super().__init__(config, parallel_config) + self.model_update_group = None + + def init_transfer(self, init_info: NCCLInitInfo) -> None: + """ + Initialize NCCL process group with the trainer. + + Args: + init_info: NCCL initialization info containing master address, port, + rank offset, and world size + """ + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + # Calculate the global rank in the trainer-worker process group + # Must account for data parallel to get unique ranks across all workers + dp_rank = self.parallel_config.data_parallel_rank + world_size_per_dp = self.parallel_config.world_size # TP * PP + tp_rank = self.parallel_config.rank + + # Unique rank across all DP groups + worker_rank = dp_rank * world_size_per_dp + tp_rank + rank = worker_rank + init_info.rank_offset + # Create stateless process group + pg = StatelessProcessGroup.create( + host=init_info.master_address, + port=init_info.master_port, + rank=rank, + world_size=init_info.world_size, + ) + + # Initialize NCCL communicator + self.model_update_group = PyNcclCommunicator( + pg, device=torch.cuda.current_device() + ) + + def receive_weights( + self, + update_info: NCCLUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """ + Receive weights from trainer via NCCL broadcast and load them incrementally. + + Args: + update_info: NCCL update info containing parameter names, dtypes, and shapes + load_weights: Callable that loads weights into the model. 
Called + incrementally for each weight to avoid OOM. + """ + if self.model_update_group is None: + raise RuntimeError( + "NCCL weight transfer not initialized. Call init_transfer() first." + ) + + for name, dtype_name, shape in zip( + update_info.names, update_info.dtype_names, update_info.shapes + ): + # Get the torch dtype + dtype = getattr(torch, dtype_name) + + # Allocate buffer for receiving weight + weight = torch.empty(shape, dtype=dtype, device="cuda") + + # Broadcast from rank 0 (trainer) + self.model_update_group.broadcast( + weight, src=0, stream=torch.cuda.current_stream() + ) + + # Load weight immediately to avoid accumulating all weights in memory + load_weights([(name, weight)]) + + # Clean up the weight tensor + del weight + + def shutdown(self) -> None: + if self.model_update_group is not None: + self.model_update_group.destroy() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d5bec7e73e1e..c78494c9a975 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,7 +8,7 @@ import json import sys from collections.abc import Callable -from dataclasses import MISSING, dataclass, fields, is_dataclass +from dataclasses import MISSING, dataclass, field, fields, is_dataclass from itertools import permutations from types import UnionType from typing import ( @@ -55,6 +55,7 @@ SpeculativeConfig, StructuredOutputsConfig, VllmConfig, + WeightTransferConfig, get_attr_docs, ) from vllm.config.cache import ( @@ -239,17 +240,17 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: # Save time only getting attr docs if we're generating help text cls_docs = get_attr_docs(cls) if NEEDS_HELP else {} kwargs = {} - for field in fields(cls): + for fld in fields(cls): # Get the set of possible types for the field - type_hints: set[TypeHint] = get_type_hints(field.type) + type_hints: set[TypeHint] = get_type_hints(fld.type) # If the field is a dataclass, we can use the model_validate_json generator = (th for th in 
type_hints if is_dataclass(th)) dataclass_cls = next(generator, None) # Get the default value of the field - if field.default is not MISSING: - default = field.default + if fld.default is not MISSING: + default = fld.default # Handle pydantic.Field defaults if isinstance(default, FieldInfo): if default.default_factory is None: @@ -259,11 +260,11 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: # These could emit logs on init, which would be confusing. with suppress_logging(): default = default.default_factory() - elif field.default_factory is not MISSING: - default = field.default_factory() + elif fld.default_factory is not MISSING: + default = fld.default_factory() # Get the help text for the field - name = field.name + name = fld.name help = cls_docs.get(name, "").strip() # Escape % for argparse help = help.replace("%", "%%") @@ -577,6 +578,10 @@ class EngineArgs: ) tokens_only: bool = False + weight_transfer_config: WeightTransferConfig = field( + default_factory=WeightTransferConfig + ) + def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a @@ -1754,6 +1759,7 @@ def create_engine_config( profiler_config=self.profiler_config, additional_config=self.additional_config, optimization_level=self.optimization_level, + weight_transfer_config=self.weight_transfer_config, ) return config diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index bf656cf23de6..9f51a7c4e769 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -6,6 +6,10 @@ from typing import Any from vllm.config import ModelConfig, VllmConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightUpdateRequest, +) from vllm.inputs.data import PromptType from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, RequestOutput @@ -192,3 +196,17 @@ async def collective_rpc( async def get_supported_tasks(self) -> tuple[SupportedTask, 
...]: """Get supported tasks""" raise NotImplementedError + + async def init_weight_transfer( + self, init_request: WeightTransferInitRequest + ) -> None: + """Initialize weight transfer for RL training.""" + raise NotImplementedError + + async def update_weights(self, request: WeightUpdateRequest) -> None: + """Batched weight update for RL training.""" + raise NotImplementedError + + async def finalize_weight_update(self) -> None: + """Finalize the current weight update during RL training.""" + raise NotImplementedError diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1a1874b7d6d2..8019018344a5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -33,6 +33,10 @@ RunnerOption, TokenizerMode, ) +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightUpdateRequest, +) from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, @@ -1786,3 +1790,44 @@ def _run_engine( # This is necessary because some requests may be finished earlier than # its previous requests. return sorted(outputs, key=lambda x: int(x.request_id)) + + def init_weight_transfer(self, request: WeightTransferInitRequest) -> None: + """ + Initialize weight transfer for RL training. + + Args: + request: Weight transfer initialization request with backend-specific info + """ + + if isinstance(request, WeightTransferInitRequest): + init_info_dict = request.init_info + else: + raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}") + + self.llm_engine.collective_rpc( + "init_weight_transfer", kwargs={"init_info": init_info_dict} + ) + + def update_weights(self, request: WeightUpdateRequest) -> None: + """ + Update the weights of the model. 
+ + Args: + request: Weight update request with backend-specific update info + """ + + if hasattr(request, "update_info"): + update_info_dict = request.update_info + else: + raise TypeError(f"Invalid `WeightUpdateRequest` format: {type(request)}") + + self.llm_engine.collective_rpc( + "update_weights", kwargs={"update_info": update_info_dict} + ) + + def finalize_weight_update(self) -> None: + """ + Finalize the weight update by processing weights for quantization/kernel format. + This should be called after all weight updates are complete. + """ + self.llm_engine.collective_rpc("finalize_weight_update") diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b4c46bb66e7c..f3eb56e687c3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -31,6 +31,9 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.distributed.weight_transfer import WeightUpdateRequest +from vllm.distributed.weight_transfer.base import WeightTransferInitRequest from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient from vllm.entrypoints.anthropic.protocol import ( @@ -638,6 +641,47 @@ async def create_translations( return StreamingResponse(content=generator, media_type="text/event-stream") +@router.post("/init_weight_transfer") +async def init_weight_transfer(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 + init_info = body.get("init_info") + if init_info is None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'init_info' in request body", + ) + await engine_client(raw_request).init_weight_transfer( + WeightTransferInitRequest(init_info=init_info) + ) + return JSONResponse(content={"message": "Weight 
transfer initialized"}) + + +@router.post("/update_weights") +async def update_weights(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 + update_info = body.get("update_info") + if update_info is None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'update_info' in request body", + ) + await engine_client(raw_request).update_weights( + request=WeightUpdateRequest(update_info=update_info) + ) + return JSONResponse(content={"message": "Weights updated"}) + + +@router.post("/finalize_weight_update") +async def finalize_weight_update(raw_request: Request): + await engine_client(raw_request).finalize_weight_update() + return JSONResponse(content={"message": "Weight update finalized"}) + def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 1c93d63fe49a..b9eb4b997c44 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -14,6 +14,10 @@ import vllm.envs as envs from vllm.config import VllmConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightUpdateRequest, +) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient from vllm.entrypoints.utils import _validate_truncation_size @@ -869,3 +873,46 @@ def errored(self) -> bool: @property def dead_error(self) -> BaseException: return EngineDeadError() + + async def init_weight_transfer(self, request: WeightTransferInitRequest) -> None: + """ + Initialize weight transfer for RL training. 
+ + Args: + request: Weight transfer initialization request with backend-specific info + """ + from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + ) + + if isinstance(request, WeightTransferInitRequest): + init_info_dict = request.init_info + else: + raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}") + + await self.collective_rpc( + "init_weight_transfer", kwargs={"init_info": init_info_dict} + ) + + async def update_weights(self, request: WeightUpdateRequest) -> None: + """ + Batched weight update for RL training. + + Args: + request: Weight update request with backend-specific update info + """ + + if hasattr(request, "update_info"): + update_info_dict = request.update_info + else: + raise TypeError(f"Invalid WeightUpdateRequest format: {type(request)}") + + await self.collective_rpc( + "update_weights", kwargs={"update_info": update_info_dict} + ) + + async def finalize_weight_update(self) -> None: + """ + Finalize the current weight update during RL training. 
+ """ + await self.collective_rpc("finalize_weight_update") diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index fd4ee596c30e..649e7086c220 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -32,6 +32,9 @@ get_pp_group, get_tp_group, ) +from vllm.distributed.weight_transfer import ( + init_transfer_engine, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.models.interfaces import is_mixture_of_experts @@ -95,6 +98,18 @@ def __init__( # Buffers saved before sleep self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + # Weight transfer engine (initialized on-demand) + # check if class is in the map + self.weight_transfer_engine = init_transfer_engine( + self.vllm_config.weight_transfer_config, self.vllm_config.parallel_config + ) + + # Weight transfer engine (initialized on-demand) + # check if class is in the map + self.weight_transfer_engine = init_transfer_engine( + self.vllm_config.weight_transfer_config, self.vllm_config.parallel_config + ) + # Torch/CUDA profiler. Enabled and configured through profiler_config. self.profiler: Any | None = None profiler_config = vllm_config.profiler_config @@ -922,12 +937,60 @@ def save_tensorized_model( tensorizer_config=tensorizer_config, ) + def init_weight_transfer(self, init_info: dict) -> None: + """ + Initialize weight transfer mechanism. + For NCCL backend, this creates a process group with the trainer. + + Args: + init_info: Dictionary containing backend-specific initialization info + """ + # Parse dict into backend-specific typed dataclass + typed_init_info = self.weight_transfer_engine.parse_init_info(init_info) + self.weight_transfer_engine.init_transfer(typed_init_info) + + def update_weights(self, update_info: dict) -> None: + """ + Batched weight update from the trainer. 
+ + Args: + update_info: Dictionary containing backend-specific update info + """ + if self.weight_transfer_engine is None: + raise RuntimeError("Weight transfer not initialized.") + + # Parse dict into backend-specific typed dataclass + typed_update_info = self.weight_transfer_engine.parse_update_info(update_info) + + # Receive and load weights incrementally to avoid OOM + self.weight_transfer_engine.receive_weights( + typed_update_info, + load_weights=self.model_runner.model.load_weights, + ) + + def finalize_weight_update(self) -> None: + """ + Finalize the weight update by processing weights for quantization/kernel format. + This should be called after all weight updates are complete. + """ + from vllm.model_executor.model_loader.utils import process_weights_after_loading + + process_weights_after_loading( + self.model_runner.model, self.model_config, self.device + ) + def shutdown(self) -> None: if runner := getattr(self, "model_runner", None): runner.ensure_kv_transfer_shutdown() if self.profiler is not None: self.profiler.shutdown() + if weight_transfer_engine := getattr(self, "weight_transfer_engine", None): + weight_transfer_engine.shutdown() + + if weight_transfer_engine := getattr(self, "weight_transfer_engine", None): + weight_transfer_engine.shutdown() + def init_worker_distributed_environment( vllm_config: VllmConfig, From 037f968c6363293b0c703bbecd4611148d744b4c Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Wed, 7 Jan 2026 18:31:10 -0800 Subject: [PATCH 02/36] updated async, added world size endpoints Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf.py | 10 +++--- .../new_weight_syncing/rlhf_async_new_apis.py | 33 +++++++++---------- .../new_weight_syncing/rlhf_http.py | 26 +++++++++------ vllm/entrypoints/llm.py | 9 +++++ vllm/entrypoints/openai/api_server.py | 11 ++++++- 5 files changed, 56 insertions(+), 33 deletions(-) diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py 
b/examples/offline_inference/new_weight_syncing/rlhf.py index 6f69caae2b84..c2861d3be4fd 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -92,12 +92,11 @@ def __init__(self, *args, **kwargs): )(MyLLM).remote( model=MODEL_NAME, enforce_eager=True, - tensor_parallel_size=1, - data_parallel_size=2, + tensor_parallel_size=2, + data_parallel_size=1, distributed_executor_backend="ray", weight_transfer_config=WeightTransferConfig(backend="nccl"), enable_expert_parallel=True, - all2all_backend="pplx", ) # Generate text from the prompts. @@ -124,6 +123,7 @@ def __init__(self, *args, **kwargs): master_address = get_ip() master_port = get_open_port() +world_size = ray.get(llm.get_world_size.remote()) + 1 # +1 for the trainer handle = llm.init_weight_transfer.remote( WeightTransferInitRequest( init_info=asdict( @@ -131,14 +131,14 @@ def __init__(self, *args, **kwargs): master_address=master_address, master_port=master_port, rank_offset=1, - world_size=3, + world_size=world_size, ) ) ) ) model_update_group = stateless_init_process_group( - master_address, master_port, 0, 3, torch.device("cuda:0") + master_address, master_port, 0, world_size, torch.device("cuda:0") ) ray.get(handle) diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index 0bd38428d5b5..819c592a8caa 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -47,6 +47,9 @@ ) from vllm.utils.network_utils import get_ip, get_open_port +# MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507" +MODEL_NAME = "facebook/opt-125m" + class MyLLM: """Simple wrapper over AsyncLLM for supporting async RL.""" @@ -81,18 +84,10 @@ async def generate_with_retry( async def abort_generation(self) -> None: self.generation_paused_event.set() - 
unfinished_request_ids = list( - self.engine.output_processor.request_states.keys() - ) - if unfinished_request_ids: - await self.engine.abort(unfinished_request_ids) - await self.engine.reset_prefix_cache() - print( - f"abort_generation() finished, aborted" - f"{len(unfinished_request_ids)} requests" - ) + return await self.engine.pause_generation(wait_for_inflight_requests=False) async def resume_generation(self) -> None: + await self.engine.resume_generation() self.generation_paused_event.clear() async def collective_rpc(self, method: str, args: tuple = ()): @@ -116,9 +111,7 @@ async def finalize_weight_update(self) -> None: # Load the OPT-125M model onto GPU 0 for the training workload. -train_model = AutoModelForCausalLM.from_pretrained( - "facebook/opt-125m", dtype=torch.bfloat16 -) +train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16) train_model.to("cuda:0") # Initialize Ray and set the visible devices. The vLLM engine will @@ -150,7 +143,7 @@ async def finalize_weight_update(self) -> None: num_gpus=0, scheduling_strategy=scheduling_inference, )(MyLLM).remote( - model="facebook/opt-125m", + model=MODEL_NAME, enforce_eager=True, tensor_parallel_size=2, distributed_executor_backend="ray", @@ -159,13 +152,18 @@ async def finalize_weight_update(self) -> None: # Generate text from the prompts. prompts = [ - "Hello, my name is", + "My name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] -sampling_params = SamplingParams(temperature=0) +sampling_params = [ + SamplingParams(temperature=0, max_tokens=2), + SamplingParams(temperature=0, max_tokens=32), + SamplingParams(temperature=0, max_tokens=32), + SamplingParams(temperature=0, max_tokens=32), +] # Set up the communication channel between the training process and the # inference engine. 
@@ -191,7 +189,8 @@ async def finalize_weight_update(self) -> None: generation_futures = [ - llm.generate_with_retry.remote(prompt, sampling_params) for prompt in prompts + llm.generate_with_retry.remote(prompt, params) + for prompt, params in zip(prompts, sampling_params) ] finished, pending = ray.wait(generation_futures, num_returns=1) diff --git a/examples/offline_inference/new_weight_syncing/rlhf_http.py b/examples/offline_inference/new_weight_syncing/rlhf_http.py index cd321bb5f895..5a885b0076b6 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_http.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_http.py @@ -47,12 +47,6 @@ MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507" -INFERENCE_WORLD_SIZE = 2 -WORLD_SIZE = INFERENCE_WORLD_SIZE + 1 - -DEVICE = f"cuda:{INFERENCE_WORLD_SIZE}" - - def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]: """Generate completions using the OpenAI-compatible API.""" results = [] @@ -118,11 +112,24 @@ def finalize_weight_update(base_url: str) -> None: response.raise_for_status() +def get_world_size(base_url: str) -> int: + """Get world size from the vLLM server.""" + url = f"{base_url}/get_world_size" + response = requests.get(url, timeout=10) + response.raise_for_status() + return response.json()["world_size"] + + def main(): + # Get the inference world size from the vLLM server + inference_world_size = get_world_size(BASE_URL) + world_size = inference_world_size + 1 # +1 for the trainer + device = f"cuda:{inference_world_size}" + # Load the training model print(f"Loading training model: {MODEL_NAME}") train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16) - train_model.to(DEVICE) + train_model.to(device) # Create OpenAI client pointing to the vLLM server client = OpenAI( @@ -161,14 +168,13 @@ def main(): init_thread = threading.Thread( target=init_weight_transfer, - args=(BASE_URL, master_address, master_port, rank_offset, WORLD_SIZE), + 
args=(BASE_URL, master_address, master_port, rank_offset, world_size), ) init_thread.start() # Initialize NCCL process group on trainer side - device = torch.device(DEVICE) model_update_group = stateless_init_process_group( - master_address, master_port, 0, WORLD_SIZE, device + master_address, master_port, 0, world_size, torch.device(device) ) # Wait for init_weight_transfer to complete diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8019018344a5..6d6d5b405884 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -355,6 +355,15 @@ def _make_config(value: Any, cls: type[_R]) -> _R: def get_tokenizer(self) -> TokenizerLike: return self.llm_engine.get_tokenizer() + def get_world_size(self) -> int: + """Get the world size from the parallel config. + + Returns: + The world size including data parallelism + (tensor_parallel_size * pipeline_parallel_size * data_parallel_size). + """ + return self.llm_engine.vllm_config.parallel_config.world_size_across_dp + def reset_mm_cache(self) -> None: self.input_processor.clear_mm_cache() self.llm_engine.reset_mm_cache() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f3eb56e687c3..1a3eb4c79912 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -31,7 +31,6 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send import vllm.envs as envs -from vllm.config import VllmConfig from vllm.distributed.weight_transfer import WeightUpdateRequest from vllm.distributed.weight_transfer.base import WeightTransferInitRequest from vllm.engine.arg_utils import AsyncEngineArgs @@ -682,6 +681,16 @@ async def finalize_weight_update(raw_request: Request): await engine_client(raw_request).finalize_weight_update() return JSONResponse(content={"message": "Weight update finalized"}) + +@router.get("/get_world_size") +async def get_world_size(raw_request: Request): + """Get the world size from the parallel config 
(TP * PP * DP).""" + world_size = engine_client( + raw_request + ).vllm_config.parallel_config.world_size_across_dp + return JSONResponse(content={"world_size": world_size}) + + def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None From 676c7e3d9cb1f1eeac0fa3ee4035c7ccbc59a553 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 8 Jan 2026 10:26:11 -0800 Subject: [PATCH 03/36] bugfixes Signed-off-by: ahao-anyscale --- vllm/distributed/weight_transfer/__init__.py | 2 +- vllm/v1/worker/gpu_worker.py | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/vllm/distributed/weight_transfer/__init__.py b/vllm/distributed/weight_transfer/__init__.py index bd06049e3bc4..4f7ec6f4975b 100644 --- a/vllm/distributed/weight_transfer/__init__.py +++ b/vllm/distributed/weight_transfer/__init__.py @@ -44,7 +44,7 @@ def init_transfer_engine(config: WeightTransferConfig, parallel_config: Parallel "WeightTransferEngine", "NCCLWeightTransferEngine", "register_weight_transfer_engine", - "WEIGHT_TRANSFER_ENGINE_MAP", + "WEIGHT_TRANSFER_ENGINE_REGISTRY", "IPCWeightTransferEngine", "WeightUpdateRequest", ] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 649e7086c220..39c75e3a9b24 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -104,12 +104,6 @@ def __init__( self.vllm_config.weight_transfer_config, self.vllm_config.parallel_config ) - # Weight transfer engine (initialized on-demand) - # check if class is in the map - self.weight_transfer_engine = init_transfer_engine( - self.vllm_config.weight_transfer_config, self.vllm_config.parallel_config - ) - # Torch/CUDA profiler. Enabled and configured through profiler_config. 
self.profiler: Any | None = None profiler_config = vllm_config.profiler_config @@ -988,9 +982,6 @@ def shutdown(self) -> None: if weight_transfer_engine := getattr(self, "weight_transfer_engine", None): weight_transfer_engine.shutdown() - if weight_transfer_engine := getattr(self, "weight_transfer_engine", None): - weight_transfer_engine.shutdown() - def init_worker_distributed_environment( vllm_config: VllmConfig, From 0a935dd7cd0a4bffa9456a703a5219985725dc83 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 8 Jan 2026 10:49:43 -0800 Subject: [PATCH 04/36] ipc fix Signed-off-by: ahao-anyscale --- vllm/distributed/weight_transfer/ipc_engine.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/weight_transfer/ipc_engine.py b/vllm/distributed/weight_transfer/ipc_engine.py index d13b3fab4a3b..758b7c06df53 100644 --- a/vllm/distributed/weight_transfer/ipc_engine.py +++ b/vllm/distributed/weight_transfer/ipc_engine.py @@ -86,8 +86,10 @@ def init_transfer(self, init_info: IPCInitInfo) -> None: pass def receive_weights( - self, update_info: IPCUpdateInfo - ) -> list[tuple[str, torch.Tensor]]: + self, + update_info: IPCUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: """ Receive weights from the trainer via CUDA IPC handles. 
@@ -118,7 +120,7 @@ def receive_weights( weight = func(*list_args) # type: ignore weights.append((name, weight)) - return weights + load_weights(weights) def shutdown(self) -> None: """ From 1f29fcdc95397c9b34882a3c2414d3cdfd5f9343 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 8 Jan 2026 15:25:07 -0800 Subject: [PATCH 05/36] moved rlhf scripts Signed-off-by: ahao-anyscale --- examples/offline_inference/{ => weight_syncing/legacy}/rlhf.py | 0 .../{ => weight_syncing/legacy}/rlhf_colocate.py | 0 .../{ => weight_syncing/legacy}/rlhf_online_quant.py | 0 .../offline_inference/{ => weight_syncing/legacy}/rlhf_utils.py | 0 .../{new_weight_syncing => weight_syncing}/rlhf.py | 0 .../{new_weight_syncing => weight_syncing}/rlhf_async_new_apis.py | 0 .../{new_weight_syncing => weight_syncing}/rlhf_http.py | 0 .../{new_weight_syncing => weight_syncing}/rlhf_ipc.py | 0 .../{new_weight_syncing => weight_syncing}/rlhf_utils.py | 0 9 files changed, 0 insertions(+), 0 deletions(-) rename examples/offline_inference/{ => weight_syncing/legacy}/rlhf.py (100%) rename examples/offline_inference/{ => weight_syncing/legacy}/rlhf_colocate.py (100%) rename examples/offline_inference/{ => weight_syncing/legacy}/rlhf_online_quant.py (100%) rename examples/offline_inference/{ => weight_syncing/legacy}/rlhf_utils.py (100%) rename examples/offline_inference/{new_weight_syncing => weight_syncing}/rlhf.py (100%) rename examples/offline_inference/{new_weight_syncing => weight_syncing}/rlhf_async_new_apis.py (100%) rename examples/offline_inference/{new_weight_syncing => weight_syncing}/rlhf_http.py (100%) rename examples/offline_inference/{new_weight_syncing => weight_syncing}/rlhf_ipc.py (100%) rename examples/offline_inference/{new_weight_syncing => weight_syncing}/rlhf_utils.py (100%) diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/weight_syncing/legacy/rlhf.py similarity index 100% rename from examples/offline_inference/rlhf.py rename to 
examples/offline_inference/weight_syncing/legacy/rlhf.py diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/weight_syncing/legacy/rlhf_colocate.py similarity index 100% rename from examples/offline_inference/rlhf_colocate.py rename to examples/offline_inference/weight_syncing/legacy/rlhf_colocate.py diff --git a/examples/offline_inference/rlhf_online_quant.py b/examples/offline_inference/weight_syncing/legacy/rlhf_online_quant.py similarity index 100% rename from examples/offline_inference/rlhf_online_quant.py rename to examples/offline_inference/weight_syncing/legacy/rlhf_online_quant.py diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/weight_syncing/legacy/rlhf_utils.py similarity index 100% rename from examples/offline_inference/rlhf_utils.py rename to examples/offline_inference/weight_syncing/legacy/rlhf_utils.py diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/weight_syncing/rlhf.py similarity index 100% rename from examples/offline_inference/new_weight_syncing/rlhf.py rename to examples/offline_inference/weight_syncing/rlhf.py diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/weight_syncing/rlhf_async_new_apis.py similarity index 100% rename from examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py rename to examples/offline_inference/weight_syncing/rlhf_async_new_apis.py diff --git a/examples/offline_inference/new_weight_syncing/rlhf_http.py b/examples/offline_inference/weight_syncing/rlhf_http.py similarity index 100% rename from examples/offline_inference/new_weight_syncing/rlhf_http.py rename to examples/offline_inference/weight_syncing/rlhf_http.py diff --git a/examples/offline_inference/new_weight_syncing/rlhf_ipc.py b/examples/offline_inference/weight_syncing/rlhf_ipc.py similarity index 100% rename from examples/offline_inference/new_weight_syncing/rlhf_ipc.py 
rename to examples/offline_inference/weight_syncing/rlhf_ipc.py diff --git a/examples/offline_inference/new_weight_syncing/rlhf_utils.py b/examples/offline_inference/weight_syncing/rlhf_utils.py similarity index 100% rename from examples/offline_inference/new_weight_syncing/rlhf_utils.py rename to examples/offline_inference/weight_syncing/rlhf_utils.py From 3471efbb039652f56f44cd7903b7d4e85b47e31a Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 8 Jan 2026 18:02:51 -0800 Subject: [PATCH 06/36] added tests Signed-off-by: ahao-anyscale --- .../weight_syncing/rlhf_ipc.py | 4 - tests/distributed/test_weight_transfer.py | 441 ++++++++++++++++++ tests/entrypoints/weight_transfer/__init__.py | 3 + .../test_weight_transfer_llm.py | 394 ++++++++++++++++ vllm/config/weight_transfer.py | 2 +- .../weight_transfer/nccl_engine.py | 3 +- vllm/engine/arg_utils.py | 17 + 7 files changed, 858 insertions(+), 6 deletions(-) create mode 100644 tests/distributed/test_weight_transfer.py create mode 100644 tests/entrypoints/weight_transfer/__init__.py create mode 100644 tests/entrypoints/weight_transfer/test_weight_transfer_llm.py diff --git a/examples/offline_inference/weight_syncing/rlhf_ipc.py b/examples/offline_inference/weight_syncing/rlhf_ipc.py index 7ed0a4dbf9af..1d7bff638b25 100644 --- a/examples/offline_inference/weight_syncing/rlhf_ipc.py +++ b/examples/offline_inference/weight_syncing/rlhf_ipc.py @@ -118,10 +118,6 @@ def zero_data(self): ray.init(runtime_env={"excludes": [".git/objects/pack/"]}) -# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. 
-# Learn more about Ray placement groups: -# https://docs.ray.io/en/latest/placement-groups.html - pg_colocate = placement_group([{"GPU": 1, "CPU": 0}]) ray.get(pg_colocate.ready()) diff --git a/tests/distributed/test_weight_transfer.py b/tests/distributed/test_weight_transfer.py new file mode 100644 index 000000000000..8fb33723fd76 --- /dev/null +++ b/tests/distributed/test_weight_transfer.py @@ -0,0 +1,441 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for weight transfer engine backends. + +Unit tests for engine classes (parsing, validation, registry). +Integration test for NCCL weight transfer between processes using Ray. +""" + +from dataclasses import dataclass +from unittest.mock import MagicMock + +import pytest +import ray +import torch + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer import ( + WEIGHT_TRANSFER_ENGINE_REGISTRY, + init_transfer_engine, + register_weight_transfer_engine, +) +from vllm.distributed.weight_transfer.base import ( + BackendInitInfo, + WeightTransferEngine, +) +from vllm.distributed.weight_transfer.ipc_engine import ( + IPCInitInfo, + IPCUpdateInfo, + IPCWeightTransferEngine, +) +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLInitInfo, + NCCLUpdateInfo, + NCCLWeightTransferEngine, +) +from vllm.utils.network_utils import get_open_port + + +def create_mock_parallel_config( + rank: int = 0, + world_size: int = 1, + dp_rank: int = 0, +) -> ParallelConfig: + """Create a mock ParallelConfig for testing.""" + config = MagicMock(spec=ParallelConfig) + config.rank = rank + config.world_size = world_size + config.data_parallel_rank = dp_rank + return config + + +# --- Unit Tests: NCCLUpdateInfo Validation --- + + +class TestNCCLUpdateInfoValidation: + """Test NCCLUpdateInfo dataclass validation.""" + + def test_valid_update_info(self): + """Test 
creating valid NCCLUpdateInfo.""" + info = NCCLUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32", "float32"], + shapes=[[10, 10], [10]], + ) + assert info.names == ["layer.weight", "layer.bias"] + assert info.dtype_names == ["float32", "float32"] + assert info.shapes == [[10, 10], [10]] + + def test_mismatched_dtype_names_raises(self): + """Test that mismatched dtype_names length raises ValueError.""" + with pytest.raises(ValueError, match="dtype_names"): + NCCLUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32"], # Only one dtype + shapes=[[10, 10], [10]], + ) + + def test_mismatched_shapes_raises(self): + """Test that mismatched shapes length raises ValueError.""" + with pytest.raises(ValueError, match="shapes"): + NCCLUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32", "float32"], + shapes=[[10, 10]], # Only one shape + ) + + def test_empty_lists_valid(self): + """Test that empty lists are valid.""" + info = NCCLUpdateInfo( + names=[], + dtype_names=[], + shapes=[], + ) + assert len(info.names) == 0 + + +# --- Unit Tests: IPCUpdateInfo Validation --- + + +class TestIPCUpdateInfoValidation: + """Test IPCUpdateInfo dataclass validation.""" + + def test_valid_update_info(self): + """Test creating valid IPCUpdateInfo.""" + info = IPCUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + ipc_handles=[{"gpu-uuid": (lambda: None, ())}], + ) + assert info.names == ["layer.weight"] + + def test_mismatched_ipc_handles_raises(self): + """Test that mismatched ipc_handles length raises ValueError.""" + with pytest.raises(ValueError, match="ipc_handles"): + IPCUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32", "float32"], + shapes=[[10, 10], [10]], + ipc_handles=[{}], # Only one handle + ) + + +# --- Unit Tests: Engine Parsing --- + + +class TestNCCLEngineParsing: + """Test NCCLWeightTransferEngine parsing methods.""" + + def 
test_parse_init_info_valid(self): + """Test parsing valid init info dict.""" + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = NCCLWeightTransferEngine(config, parallel_config) + + init_info = engine.parse_init_info( + { + "master_address": "127.0.0.1", + "master_port": 12345, + "rank_offset": 1, + "world_size": 3, + } + ) + + assert isinstance(init_info, NCCLInitInfo) + assert init_info.master_address == "127.0.0.1" + assert init_info.master_port == 12345 + assert init_info.rank_offset == 1 + assert init_info.world_size == 3 + + def test_parse_init_info_missing_field_raises(self): + """Test parsing init info with missing required field.""" + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = NCCLWeightTransferEngine(config, parallel_config) + + with pytest.raises(ValueError, match="Invalid init_info"): + engine.parse_init_info( + { + "master_address": "127.0.0.1", + # Missing master_port, rank_offset, world_size + } + ) + + def test_parse_update_info_valid(self): + """Test parsing valid update info dict.""" + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = NCCLWeightTransferEngine(config, parallel_config) + + update_info = engine.parse_update_info( + { + "names": ["w1", "w2"], + "dtype_names": ["float32", "bfloat16"], + "shapes": [[100, 100], [50]], + } + ) + + assert isinstance(update_info, NCCLUpdateInfo) + assert update_info.names == ["w1", "w2"] + assert update_info.dtype_names == ["float32", "bfloat16"] + assert update_info.shapes == [[100, 100], [50]] + + +class TestIPCEngineParsing: + """Test IPCWeightTransferEngine parsing methods.""" + + def test_parse_init_info_empty(self): + """Test parsing empty init info (IPC doesn't need init params).""" + config = WeightTransferConfig(backend="ipc") + parallel_config = create_mock_parallel_config() + engine = IPCWeightTransferEngine(config, 
parallel_config) + + init_info = engine.parse_init_info({}) + assert isinstance(init_info, IPCInitInfo) + + def test_init_transfer_is_noop(self): + """Test that IPC init_transfer is a no-op.""" + config = WeightTransferConfig(backend="ipc") + parallel_config = create_mock_parallel_config() + engine = IPCWeightTransferEngine(config, parallel_config) + + # Should not raise + engine.init_transfer(IPCInitInfo()) + + +# --- Unit Tests: Engine Registry --- + + +class TestEngineRegistry: + """Test weight transfer engine registry.""" + + def test_init_transfer_engine_nccl(self): + """Test init_transfer_engine creates NCCL engine.""" + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = init_transfer_engine(config, parallel_config) + assert isinstance(engine, NCCLWeightTransferEngine) + + def test_init_transfer_engine_ipc(self): + """Test init_transfer_engine creates IPC engine.""" + config = WeightTransferConfig(backend="ipc") + parallel_config = create_mock_parallel_config() + engine = init_transfer_engine(config, parallel_config) + assert isinstance(engine, IPCWeightTransferEngine) + + def test_init_transfer_engine_invalid_backend(self): + """Test init_transfer_engine raises for invalid backend.""" + config = WeightTransferConfig(backend="invalid") + parallel_config = create_mock_parallel_config() + with pytest.raises(ValueError, match="Invalid weight transfer backend"): + init_transfer_engine(config, parallel_config) + + def test_register_custom_engine(self): + """Test registering a custom engine.""" + + @dataclass + class CustomInitInfo(BackendInitInfo): + pass + + class CustomEngine(WeightTransferEngine): + init_info_cls = CustomInitInfo + update_info_cls = NCCLUpdateInfo # Reuse for simplicity + + def init_transfer(self, init_info): + pass + + def receive_weights(self, update_info, load_weights): + pass + + def shutdown(self): + pass + + # Register custom engine + register_weight_transfer_engine("custom_test", 
CustomEngine) + assert "custom_test" in WEIGHT_TRANSFER_ENGINE_REGISTRY + + # Clean up + del WEIGHT_TRANSFER_ENGINE_REGISTRY["custom_test"] + + def test_register_duplicate_raises(self): + """Test registering duplicate engine name raises.""" + with pytest.raises(ValueError, match="already registered"): + register_weight_transfer_engine("nccl", NCCLWeightTransferEngine) + + +# --- Test receive_weights without init raises --- + + +def test_nccl_receive_weights_without_init_raises(): + """Test that receive_weights raises if init_transfer wasn't called.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + config = WeightTransferConfig(backend="nccl") + parallel_config = create_mock_parallel_config() + engine = NCCLWeightTransferEngine(config, parallel_config) + + update_info = NCCLUpdateInfo( + names=["w"], + dtype_names=["float32"], + shapes=[[10]], + ) + + with pytest.raises(RuntimeError, match="not initialized"): + engine.receive_weights(update_info, lambda x: None) + + +# --- Integration Test: NCCL Weight Transfer Between Ray Tasks --- + + +@ray.remote(num_gpus=1) +def trainer_broadcast_tensor( + master_address: str, + master_port: int, + world_size: int, + tensor_shape: list[int], + tensor_dtype: str, +) -> bool: + """Trainer task that broadcasts a tensor via NCCL.""" + import torch + + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + # Create process group as rank 0 (trainer) + pg = StatelessProcessGroup.create( + host=master_address, + port=master_port, + rank=0, + world_size=world_size, + ) + # Ray sets CUDA_VISIBLE_DEVICES, so device 0 is the assigned GPU + comm = PyNcclCommunicator(pg, device=0) + + # Create and broadcast the tensor + dtype = getattr(torch, tensor_dtype) + tensor_to_send = torch.ones(tensor_shape, dtype=dtype, device="cuda:0") + comm.broadcast(tensor_to_send, src=0, stream=torch.cuda.current_stream()) + 
torch.cuda.synchronize() + + return True + + +@ray.remote(num_gpus=1) +def inference_receive_tensor( + master_address: str, + master_port: int, + world_size: int, + tensor_shape: list[int], + tensor_dtype: str, +) -> dict: + """Inference task that receives tensor via NCCLWeightTransferEngine.""" + from unittest.mock import MagicMock + + import torch + + from vllm.config.parallel import ParallelConfig + from vllm.config.weight_transfer import WeightTransferConfig + from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLInitInfo, + NCCLUpdateInfo, + NCCLWeightTransferEngine, + ) + + # Create engine with mock parallel config + config = WeightTransferConfig(backend="nccl") + parallel_config = MagicMock(spec=ParallelConfig) + parallel_config.rank = 0 + parallel_config.world_size = 1 + parallel_config.data_parallel_rank = 0 + + engine = NCCLWeightTransferEngine(config, parallel_config) + + # Initialize the engine (joins as rank 1) + init_info = NCCLInitInfo( + master_address=master_address, + master_port=master_port, + rank_offset=1, # Trainer is rank 0, we become rank 1 + world_size=world_size, + ) + engine.init_transfer(init_info) + + # Receive weights with a no-op load_weights that captures the tensor + received_tensors = [] + + def noop_load_weights(weights: list[tuple[str, torch.Tensor]]): + for name, tensor in weights: + # Clone tensor to keep it after engine cleans up + received_tensors.append((name, tensor.clone())) + + update_info = NCCLUpdateInfo( + names=["test.weight"], + dtype_names=[tensor_dtype], + shapes=[tensor_shape], + ) + engine.receive_weights(update_info, noop_load_weights) + torch.cuda.synchronize() + + # Verify we received the tensor + success = False + received_shape = None + received_sum = None + + if len(received_tensors) == 1: + name, tensor = received_tensors[0] + received_shape = list(tensor.shape) + received_sum = tensor.sum().item() + # Check shape matches and values are all 1s (trainer sends ones) + if received_shape == 
tensor_shape: + expected_sum = 1.0 * torch.tensor(tensor_shape).prod().item() + if abs(received_sum - expected_sum) < 0.01: + success = True + + engine.shutdown() + + return { + "success": success, + "received_shape": received_shape, + "received_sum": received_sum, + } + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run NCCL weight transfer test.", +) +def test_nccl_weight_transfer_between_processes(): + """Test NCCL weight transfer from trainer to inference process using Ray. + + This test verifies that the NCCLWeightTransferEngine can receive + tensors broadcast by a trainer process via NCCL. + """ + ray.init(ignore_reinit_error=True) + + master_address = "127.0.0.1" + master_port = get_open_port() + world_size = 2 # 1 trainer + 1 inference worker + + # Tensor to transfer: 100x100 ones + tensor_shape = [100, 100] + tensor_dtype = "float32" + + # Start both tasks concurrently - Ray assigns GPUs automatically + inference_future = inference_receive_tensor.remote( + master_address, master_port, world_size, tensor_shape, tensor_dtype + ) + trainer_future = trainer_broadcast_tensor.remote( + master_address, master_port, world_size, tensor_shape, tensor_dtype + ) + + # Wait for both to complete + trainer_result, result = ray.get([trainer_future, inference_future]) + + assert trainer_result, "Trainer should complete successfully" + assert result["success"], ( + f"Weight transfer failed. 
" + f"Received shape: {result['received_shape']}, " + f"Received sum: {result['received_sum']}" + ) diff --git a/tests/entrypoints/weight_transfer/__init__.py b/tests/entrypoints/weight_transfer/__init__.py new file mode 100644 index 000000000000..6655f8913623 --- /dev/null +++ b/tests/entrypoints/weight_transfer/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + diff --git a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py new file mode 100644 index 000000000000..3cb0d7d3ff44 --- /dev/null +++ b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py @@ -0,0 +1,394 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for weight transfer APIs via LLM class. + +These tests use a mock weight transfer engine to verify that the API +calls the correct methods with the right arguments, without requiring +actual NCCL communication. 
+""" + +import os +from collections.abc import Callable +from dataclasses import dataclass +from unittest.mock import patch + +import pytest +import torch + +from vllm import LLM +from vllm.config import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + BackendInitInfo, + BackendUpdateInfo, + WeightTransferEngine, + WeightTransferInitRequest, + WeightUpdateRequest, +) + +from ...utils import create_new_process_for_each_test + +# Use a tiny model for fast testing +MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" + + +# --- Mock Weight Transfer Engine --- + + +@dataclass +class MockInitInfo(BackendInitInfo): + """Mock initialization info.""" + + test_param: str = "test" + + +@dataclass +class MockUpdateInfo(BackendUpdateInfo): + """Mock update info.""" + + names: list[str] | None = None + dtype_names: list[str] | None = None + shapes: list[list[int]] | None = None + + +class MockWeightTransferEngine(WeightTransferEngine[MockInitInfo, MockUpdateInfo]): + """Mock weight transfer engine that tracks method calls.""" + + init_info_cls = MockInitInfo + update_info_cls = MockUpdateInfo + + # Class-level tracking for verification across processes + init_transfer_called: bool = False + receive_weights_called: bool = False + shutdown_called: bool = False + last_init_info: MockInitInfo | None = None + last_update_info: MockUpdateInfo | None = None + + def __init__(self, config, parallel_config): + super().__init__(config, parallel_config) + # Reset tracking on init + MockWeightTransferEngine.init_transfer_called = False + MockWeightTransferEngine.receive_weights_called = False + MockWeightTransferEngine.shutdown_called = False + MockWeightTransferEngine.last_init_info = None + MockWeightTransferEngine.last_update_info = None + + def init_transfer(self, init_info: MockInitInfo) -> None: + MockWeightTransferEngine.init_transfer_called = True + MockWeightTransferEngine.last_init_info = init_info + + def receive_weights( + self, + update_info: 
MockUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + MockWeightTransferEngine.receive_weights_called = True + MockWeightTransferEngine.last_update_info = update_info + # Simulate loading weights by calling load_weights with empty list + # (In real implementation, this would receive and load actual weights) + load_weights([]) + + def shutdown(self) -> None: + MockWeightTransferEngine.shutdown_called = True + + +def mock_init_transfer_engine(config, parallel_config): + """Factory function that returns our mock engine.""" + return MockWeightTransferEngine(config, parallel_config) + + +# --- Tests --- + + +@create_new_process_for_each_test() +def test_get_world_size_tp1(): + """Test world_size is correctly configured for TP=1.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + world_size = llm.llm_engine.vllm_config.parallel_config.world_size + assert world_size == 1 + + +@create_new_process_for_each_test() +def test_get_world_size_tp2(): + """Test world_size is correctly configured for TP=2.""" + if torch.cuda.device_count() < 2: + pytest.skip("Need at least 2 GPUs for this test") + + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=2, + distributed_executor_backend="mp", + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + world_size = llm.llm_engine.vllm_config.parallel_config.world_size + assert world_size == 2 + + +@create_new_process_for_each_test() +def test_init_weight_transfer_calls_engine(): + """Test that init_weight_transfer calls the engine's init_transfer method.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + # Enable insecure serialization to allow pickling functions for collective_rpc 
+ os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + with patch( + "vllm.distributed.weight_transfer.init_transfer_engine", + mock_init_transfer_engine, + ): + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + # Verify engine was created + def check_engine_exists(self): + return self.weight_transfer_engine is not None + + results = llm.collective_rpc(check_engine_exists) + assert all(results), "Weight transfer engine should be initialized" + + # Call init_weight_transfer + llm.init_weight_transfer( + WeightTransferInitRequest(init_info={"test_param": "hello"}) + ) + + # Verify init_transfer was called on the engine + def check_init_called(self): + engine = self.weight_transfer_engine + return ( + engine.init_transfer_called, + engine.last_init_info.test_param if engine.last_init_info else None, + ) + + results = llm.collective_rpc(check_init_called) + for called, param in results: + assert called, "init_transfer should have been called" + assert param == "hello", f"Expected 'hello', got {param}" + + +@create_new_process_for_each_test() +def test_update_weights_calls_engine(): + """Test that update_weights calls the engine's receive_weights method.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + # Enable insecure serialization to allow pickling functions for collective_rpc + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + with patch( + "vllm.distributed.weight_transfer.init_transfer_engine", + mock_init_transfer_engine, + ): + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + # First init the weight transfer + llm.init_weight_transfer( + WeightTransferInitRequest(init_info={"test_param": "init"}) + ) + + # Call update_weights + test_names = ["layer.weight", 
"layer.bias"] + test_dtypes = ["float32", "float32"] + test_shapes = [[10, 10], [10]] + + llm.update_weights( + WeightUpdateRequest( + update_info={ + "names": test_names, + "dtype_names": test_dtypes, + "shapes": test_shapes, + } + ) + ) + + # Verify receive_weights was called with correct info + def check_update_called(self): + engine = self.weight_transfer_engine + if not engine.receive_weights_called: + return False, None, None, None + info = engine.last_update_info + return (True, info.names, info.dtype_names, info.shapes) + + results = llm.collective_rpc(check_update_called) + for called, names, dtypes, shapes in results: + assert called, "receive_weights should have been called" + assert names == test_names + assert dtypes == test_dtypes + assert shapes == test_shapes + + +@create_new_process_for_each_test() +def test_finalize_weight_update_runs(): + """Test that finalize_weight_update completes without error.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + with patch( + "vllm.distributed.weight_transfer.init_transfer_engine", + mock_init_transfer_engine, + ): + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + # finalize_weight_update should run without error + # (it calls process_weights_after_loading internally) + llm.finalize_weight_update() + + +@create_new_process_for_each_test() +def test_full_weight_transfer_flow(): + """Test the complete weight transfer flow: init -> update -> finalize.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + # Enable insecure serialization to allow pickling functions for collective_rpc + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + with patch( + "vllm.distributed.weight_transfer.init_transfer_engine", + mock_init_transfer_engine, + ): + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + 
load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + # Step 1: Initialize + llm.init_weight_transfer( + WeightTransferInitRequest(init_info={"test_param": "flow_test"}) + ) + + # Step 2: Update weights + llm.update_weights( + WeightUpdateRequest( + update_info={ + "names": ["test.weight"], + "dtype_names": ["bfloat16"], + "shapes": [[100, 100]], + } + ) + ) + + # Step 3: Finalize + llm.finalize_weight_update() + + # Verify the full flow completed + def check_flow(self): + engine = self.weight_transfer_engine + return { + "init_called": engine.init_transfer_called, + "update_called": engine.receive_weights_called, + "init_param": ( + engine.last_init_info.test_param if engine.last_init_info else None + ), + "update_names": ( + engine.last_update_info.names if engine.last_update_info else None + ), + } + + results = llm.collective_rpc(check_flow) + for result in results: + assert result["init_called"], "init_transfer should be called" + assert result["update_called"], "receive_weights should be called" + assert result["init_param"] == "flow_test" + assert result["update_names"] == ["test.weight"] + + +@pytest.mark.parametrize("tp_size", [1, 2]) +@create_new_process_for_each_test() +def test_weight_transfer_with_tp(tp_size): + """Test weight transfer works correctly with tensor parallelism.""" + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Need at least {tp_size} GPUs for this test") + + # Enable insecure serialization to allow pickling functions for collective_rpc + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + with patch( + "vllm.distributed.weight_transfer.init_transfer_engine", + mock_init_transfer_engine, + ): + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=tp_size, + distributed_executor_backend="ray", + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + # Run weight transfer + llm.init_weight_transfer( + 
WeightTransferInitRequest(init_info={"test_param": "tp_test"}) + ) + + llm.update_weights( + WeightUpdateRequest( + update_info={ + "names": ["w"], + "dtype_names": ["float16"], + "shapes": [[50]], + } + ) + ) + + llm.finalize_weight_update() + + # Verify all TP ranks processed the weight transfer + def check_processed(self): + engine = self.weight_transfer_engine + return engine.init_transfer_called and engine.receive_weights_called + + results = llm.collective_rpc(check_processed) + assert len(results) == tp_size + assert all(results), f"All {tp_size} workers should process weight transfer" + + +@create_new_process_for_each_test() +def test_weight_transfer_config_backend(): + """Test that WeightTransferConfig backend is properly configured.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + # Test with nccl backend + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=1, + weight_transfer_config=WeightTransferConfig(backend="nccl"), + ) + + config = llm.llm_engine.vllm_config.weight_transfer_config + assert config.backend == "nccl" diff --git a/vllm/config/weight_transfer.py b/vllm/config/weight_transfer.py index 370652e2cbb5..06cd0d68aed7 100644 --- a/vllm/config/weight_transfer.py +++ b/vllm/config/weight_transfer.py @@ -11,5 +11,5 @@ class WeightTransferConfig: """Configuration for weight transfer during RL training.""" - backend: Literal["nccl", "ipc", "rdma"] = "nccl" + backend: Literal["nccl", "ipc"] = "nccl" """The backend to use for weight transfer.""" diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index d10c7c9b36c4..230ae8615960 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -147,4 +147,5 @@ def receive_weights( def shutdown(self) -> None: if self.model_update_group is not None: - self.model_update_group.destroy() + # Clean up the 
communicator by removing the reference + self.model_update_group = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c78494c9a975..5383b0965545 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -578,6 +578,9 @@ class EngineArgs: ) tokens_only: bool = False + weight_transfer_backend: str = "nccl" + """Backend for weight transfer during RL training. Options: nccl, ipc""" + weight_transfer_config: WeightTransferConfig = field( default_factory=WeightTransferConfig ) @@ -592,6 +595,11 @@ def __post_init__(self): self.attention_config = AttentionConfig(**self.attention_config) if isinstance(self.eplb_config, dict): self.eplb_config = EPLBConfig(**self.eplb_config) + # Handle weight_transfer_backend CLI arg + if self.weight_transfer_backend is not None: + self.weight_transfer_config = WeightTransferConfig( + backend=self.weight_transfer_backend + ) # Setup plugins from vllm.plugins import load_general_plugins @@ -1167,6 +1175,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: vllm_group.add_argument( "--optimization-level", **vllm_kwargs["optimization_level"] ) + vllm_group.add_argument( + "--weight-transfer-backend", + type=str, + choices=["nccl", "ipc"], + default="nccl", + help="Backend for weight transfer during RL training. 
" + "Options: nccl (distributed), ipc (same-node shared memory)" + "Default: nccl when enabled.", + ) # Other arguments parser.add_argument( From 26249ebe7b506de3b807fbdfad27a844aa7d2225 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 8 Jan 2026 18:14:50 -0800 Subject: [PATCH 07/36] precommit fix Signed-off-by: ahao-anyscale --- vllm/distributed/weight_transfer/nccl_engine.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index 230ae8615960..6d7d731e4fc9 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -4,9 +4,13 @@ from collections.abc import Callable from dataclasses import dataclass +from typing import TYPE_CHECKING import torch +if TYPE_CHECKING: + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer.base import ( @@ -72,7 +76,7 @@ def __init__( parallel_config: The configuration for the parallel setup """ super().__init__(config, parallel_config) - self.model_update_group = None + self.model_update_group: PyNcclCommunicator | None = None def init_transfer(self, init_info: NCCLInitInfo) -> None: """ From 0d4e296cd522ae29ab530d49c91546c0c7eb2a2d Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 12 Jan 2026 18:00:22 -0800 Subject: [PATCH 08/36] added packed tensors Signed-off-by: ahao-anyscale --- .../offline_inference/weight_syncing/rlhf.py | 21 +- .../weight_syncing/rlhf_http.py | 21 +- tests/distributed/test_packed_tensor.py | 476 ++++++++++++++++++ .../weight_transfer/nccl_engine.py | 116 ++++- .../weight_transfer/packed_tensor.py | 202 ++++++++ vllm/engine/arg_utils.py | 5 +- 6 files changed, 807 insertions(+), 34 deletions(-) create mode 100644 tests/distributed/test_packed_tensor.py 
create mode 100644 vllm/distributed/weight_transfer/packed_tensor.py diff --git a/examples/offline_inference/weight_syncing/rlhf.py b/examples/offline_inference/weight_syncing/rlhf.py index c2861d3be4fd..473e6a9a840d 100644 --- a/examples/offline_inference/weight_syncing/rlhf.py +++ b/examples/offline_inference/weight_syncing/rlhf.py @@ -41,10 +41,15 @@ WeightTransferInitRequest, WeightUpdateRequest, ) -from vllm.distributed.weight_transfer.nccl_engine import NCCLInitInfo, NCCLUpdateInfo +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLInitInfo, + NCCLUpdateInfo, + NCCLWeightTransferEngine, +) from vllm.utils.network_utils import get_ip, get_open_port MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507" +# MODEL_NAME = "facebook/opt-125m" class MyLLM(LLM): @@ -57,6 +62,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) +torch.cuda.set_device("cuda:0") # Load the OPT-125M model onto GPU 0 for the training workload. train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16) train_model.to("cuda:0") @@ -96,7 +102,7 @@ def __init__(self, *args, **kwargs): data_parallel_size=1, distributed_executor_backend="ray", weight_transfer_config=WeightTransferConfig(backend="nccl"), - enable_expert_parallel=True, + # enable_expert_parallel=True, ) # Generate text from the prompts. 
@@ -159,6 +165,7 @@ def __init__(self, *args, **kwargs): shapes.append(p.shape) # Issue update_weights call with NCCL-specific update info +# packed=True enables efficient batched tensor broadcasting handle = llm.update_weights.remote( WeightUpdateRequest( update_info=asdict( @@ -166,14 +173,18 @@ def __init__(self, *args, **kwargs): names=names, dtype_names=dtype_names, shapes=shapes, + packed=True, ) ) ) ) -# Broadcast all weights from trainer -for name, p in train_model.named_parameters(): - model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) +# Broadcast all weights from trainer using the weight transfer API +NCCLWeightTransferEngine.trainer_broadcast_weights( + iterator=train_model.named_parameters(), + group=model_update_group, + packed=True, +) ray.get(handle) diff --git a/examples/offline_inference/weight_syncing/rlhf_http.py b/examples/offline_inference/weight_syncing/rlhf_http.py index 5a885b0076b6..2c402ecb8ce6 100644 --- a/examples/offline_inference/weight_syncing/rlhf_http.py +++ b/examples/offline_inference/weight_syncing/rlhf_http.py @@ -40,11 +40,15 @@ from rlhf_utils import stateless_init_process_group from transformers import AutoModelForCausalLM -from vllm.distributed.weight_transfer.nccl_engine import NCCLInitInfo, NCCLUpdateInfo +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLInitInfo, + NCCLUpdateInfo, + NCCLWeightTransferEngine, +) from vllm.utils.network_utils import get_ip, get_open_port BASE_URL = "http://localhost:8000" -MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507" +MODEL_NAME = "facebook/opt-125m" def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]: @@ -89,6 +93,7 @@ def update_weights( names: list[str], dtype_names: list[str], shapes: list[list[int]], + packed: bool = False, ) -> None: """Update weights via HTTP endpoint.""" url = f"{base_url}/update_weights" @@ -98,6 +103,7 @@ def update_weights( names=names, dtype_names=dtype_names, shapes=shapes, + 
packed=packed, ) ) } @@ -125,6 +131,7 @@ def main(): inference_world_size = get_world_size(BASE_URL) world_size = inference_world_size + 1 # +1 for the trainer device = f"cuda:{inference_world_size}" + torch.cuda.set_device(device) # Load the training model print(f"Loading training model: {MODEL_NAME}") @@ -198,16 +205,20 @@ def main(): # Start the update_weights call in a separate thread since it will block # waiting for NCCL broadcasts + # packed=True enables efficient batched tensor broadcasting update_thread = threading.Thread( target=update_weights, - args=(BASE_URL, names, dtype_names, shapes), + args=(BASE_URL, names, dtype_names, shapes, True), # packed=True ) update_thread.start() # Broadcast all weights from trainer to vLLM workers print("Broadcasting weights via NCCL...") - for name, p in train_model.named_parameters(): - model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) + NCCLWeightTransferEngine.trainer_broadcast_weights( + iterator=train_model.named_parameters(), + group=model_update_group, + packed=True, + ) # Wait for update_weights to complete update_thread.join() diff --git a/tests/distributed/test_packed_tensor.py b/tests/distributed/test_packed_tensor.py new file mode 100644 index 000000000000..fc6ff91e4920 --- /dev/null +++ b/tests/distributed/test_packed_tensor.py @@ -0,0 +1,476 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for packed tensor broadcasting functionality. + +Unit tests for packed_broadcast_producer and packed_broadcast_consumer. +These utilities enable efficient batched tensor transfer over NCCL. 
+""" + +from unittest.mock import patch + +import pytest +import torch + +from vllm.distributed.weight_transfer.nccl_engine import NCCLUpdateInfo + + +class MockCommunicationGroup: + """Mock communication group for testing producer broadcast operations.""" + + def __init__(self): + self.broadcasted_tensors: list[torch.Tensor] = [] + self.broadcast_count = 0 + self.device = torch.device("cuda:0") + + def broadcast(self, tensor, src): + """Mock broadcast that stores the tensor for later verification.""" + self.broadcasted_tensors.append(tensor.clone()) + self.broadcast_count += 1 + + +class MockConsumerCommunicationGroup: + """Mock communication group for consumer that returns pre-stored tensors.""" + + def __init__(self, tensors_to_return: list[torch.Tensor]): + self.tensors_to_return = tensors_to_return + self.current_index = 0 + self.device = torch.device("cuda:0") + + def broadcast(self, tensor, src): + """Mock broadcast that fills the tensor with pre-stored data.""" + if self.current_index < len(self.tensors_to_return): + tensor.copy_(self.tensors_to_return[self.current_index]) + self.current_index += 1 + + +def create_mock_model_params( + num_layers: int = 3, + dtype: torch.dtype = torch.float32, +) -> list[tuple[str, torch.Tensor]]: + """Create mock model parameters for testing.""" + params = [] + for i in range(num_layers): + params.append((f"layer{i}.weight", torch.randn(10, 20, dtype=dtype))) + params.append((f"layer{i}.bias", torch.randn(10, dtype=dtype))) + return params + + +def create_state_dict_info( + params: list[tuple[str, torch.Tensor]], +) -> dict[str, tuple[tuple[int, ...], torch.dtype]]: + """Create state dict info (name -> (shape, dtype)) from params.""" + return {name: (tuple(tensor.shape), tensor.dtype) for name, tensor in params} + + +# --- Unit Tests: NCCLUpdateInfo packed field --- + + +class TestNCCLUpdateInfoPacked: + """Test NCCLUpdateInfo dataclass packed field.""" + + def test_packed_default_false(self): + """Test that packed defaults 
to False.""" + info = NCCLUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + ) + assert info.packed is False + + def test_packed_can_be_set_true(self): + """Test that packed can be set to True.""" + info = NCCLUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + packed=True, + ) + assert info.packed is True + + +# --- Unit Tests: packed_broadcast_producer --- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestPackedBroadcastProducer: + """Test packed_broadcast_producer function.""" + + def test_producer_broadcasts_tensors(self): + """Test that producer broadcasts all tensors.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_producer, + ) + + params = create_mock_model_params() + params_cuda = [(name, tensor.cuda()) for name, tensor in params] + + mock_group = MockCommunicationGroup() + + # Use a small target size to force multiple batches + with patch( + "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", + return_value=500, + ): + packed_broadcast_producer( + iterator=iter(params_cuda), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + # Should have broadcasted some tensors + assert mock_group.broadcast_count > 0 + assert len(mock_group.broadcasted_tensors) > 0 + + def test_producer_single_large_tensor(self): + """Test with a single tensor larger than target size.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_producer, + ) + + # Create a large tensor + large_tensor = torch.randn(1000, 1000, dtype=torch.float32).cuda() + params = [("large_weight", large_tensor)] + + mock_group = MockCommunicationGroup() + + # Small target size to force the tensor to exceed it + with patch( + "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", + return_value=100, + ): + packed_broadcast_producer( + 
iterator=iter(params), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + # Should still broadcast the tensor (at least 1 broadcast) + assert mock_group.broadcast_count >= 1 + assert len(mock_group.broadcasted_tensors) >= 1 + + # Verify the total broadcasted size matches the tensor + expected_size = large_tensor.numel() * large_tensor.element_size() + actual_size = sum(t.numel() for t in mock_group.broadcasted_tensors) + assert actual_size == expected_size + + def test_producer_multiple_batches(self): + """Test that tensors are properly batched when exceeding target size.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_producer, + ) + + # Create many small tensors + params = [ + (f"weight_{i}", torch.randn(10, 10, dtype=torch.float32).cuda()) + for i in range(20) + ] + + mock_group = MockCommunicationGroup() + + # Small target size to force multiple batches + with patch( + "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", + return_value=2000, + ): + packed_broadcast_producer( + iterator=iter(params), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + # Should have multiple broadcasts + assert mock_group.broadcast_count > 1 + + # Total size should match sum of all tensors + expected_total = sum(t.numel() * t.element_size() for _, t in params) + actual_total = sum(t.numel() for t in mock_group.broadcasted_tensors) + assert actual_total == expected_total + + def test_producer_empty_iterator(self): + """Test producer handles empty iterator gracefully.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_producer, + ) + + mock_group = MockCommunicationGroup() + + with patch( + "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", + return_value=1000, + ): + packed_broadcast_producer( + iterator=iter([]), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + # No broadcasts for empty iterator + 
assert mock_group.broadcast_count == 0 + + +# --- Unit Tests: packed_broadcast_consumer --- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestPackedBroadcastConsumer: + """Test packed_broadcast_consumer function.""" + + def test_consumer_receives_tensors(self): + """Test that consumer receives and unpacks tensors.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_consumer, + packed_broadcast_producer, + ) + + params = create_mock_model_params() + params_cuda = [(name, tensor.cuda()) for name, tensor in params] + + # First, run producer to get the broadcasted tensors + producer_group = MockCommunicationGroup() + + with patch( + "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", + return_value=2000, + ): + packed_broadcast_producer( + iterator=iter(params_cuda), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + # Now run consumer with the broadcasted tensors + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + state_dict_info = create_state_dict_info(params_cuda) + + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() + + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + ) + + # Verify all parameters were unpacked + assert len(unpacked_tensors) == len(params) + + # Verify each tensor matches the original + for name, original_tensor in params_cuda: + assert name in unpacked_tensors + unpacked = unpacked_tensors[name] + assert unpacked.shape == original_tensor.shape + assert unpacked.dtype == original_tensor.dtype + assert torch.allclose(unpacked, original_tensor, rtol=1e-5, atol=1e-7) + + +# --- Integration Tests: Producer-Consumer Roundtrip --- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not 
available") +class TestPackedBroadcastRoundtrip: + """Test producer-consumer roundtrip behavior.""" + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_roundtrip_different_dtypes(self, dtype): + """Test roundtrip with different data types.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_consumer, + packed_broadcast_producer, + ) + + params = create_mock_model_params(num_layers=2, dtype=dtype) + params_cuda = [(name, tensor.cuda()) for name, tensor in params] + + producer_group = MockCommunicationGroup() + + with patch( + "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", + return_value=1000, + ): + packed_broadcast_producer( + iterator=iter(params_cuda), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + state_dict_info = create_state_dict_info(params_cuda) + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() + + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + ) + + # Verify roundtrip preserves data + for name, original_tensor in params_cuda: + assert name in unpacked_tensors + unpacked = unpacked_tensors[name] + assert unpacked.dtype == dtype + assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6) + + def test_roundtrip_mixed_dtypes(self): + """Test roundtrip with mixed data types.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_consumer, + packed_broadcast_producer, + ) + + # Create params with mixed dtypes + params = [ + ("layer1.weight", torch.randn(10, 20, dtype=torch.float32).cuda()), + ("layer1.bias", torch.randn(10, dtype=torch.float16).cuda()), + ("layer2.weight", torch.randn(20, 30, 
dtype=torch.bfloat16).cuda()), + ] + + producer_group = MockCommunicationGroup() + + with patch( + "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", + return_value=500, + ): + packed_broadcast_producer( + iterator=iter(params), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + state_dict_info = create_state_dict_info(params) + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() + + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + ) + + # Verify all params roundtrip correctly with correct dtypes + for name, original_tensor in params: + assert name in unpacked_tensors + unpacked = unpacked_tensors[name] + assert unpacked.shape == original_tensor.shape + assert unpacked.dtype == original_tensor.dtype + assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6) + + @pytest.mark.parametrize("target_size", [100, 1000, 10000, 100000]) + def test_roundtrip_different_batch_sizes(self, target_size): + """Test roundtrip with different target batch sizes.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_consumer, + packed_broadcast_producer, + ) + + params = create_mock_model_params(num_layers=5) + params_cuda = [(name, tensor.cuda()) for name, tensor in params] + + producer_group = MockCommunicationGroup() + + with patch( + "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", + return_value=target_size, + ): + packed_broadcast_producer( + iterator=iter(params_cuda), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + state_dict_info = 
create_state_dict_info(params_cuda) + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() + + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + ) + + # Verify all params roundtrip correctly + assert len(unpacked_tensors) == len(params) + for name, original_tensor in params_cuda: + assert name in unpacked_tensors + assert torch.allclose( + unpacked_tensors[name], original_tensor, rtol=1e-5, atol=1e-7 + ) + + +# --- Unit Tests: get_target_packed_tensor_size --- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestGetTargetPackedTensorSize: + """Test get_target_packed_tensor_size function.""" + + def test_returns_positive_value(self): + """Test that function returns a positive value.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + get_target_packed_tensor_size, + ) + + # Clear cache to get fresh value + get_target_packed_tensor_size.cache_clear() + size = get_target_packed_tensor_size() + assert size > 0 + + def test_is_cached(self): + """Test that function result is cached.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + get_target_packed_tensor_size, + ) + + get_target_packed_tensor_size.cache_clear() + size1 = get_target_packed_tensor_size() + size2 = get_target_packed_tensor_size() + assert size1 == size2 + + def test_respects_max_buffer_size(self): + """Test that function respects max buffer size.""" + from vllm.distributed.weight_transfer.packed_tensor import ( + REFIT_MAX_BUFFER_SIZE, + get_target_packed_tensor_size, + ) + + get_target_packed_tensor_size.cache_clear() + size = get_target_packed_tensor_size() + assert size <= REFIT_MAX_BUFFER_SIZE diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index 6d7d731e4fc9..09f7f01b34f9 100644 --- 
a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """NCCL-based weight transfer engine.""" -from collections.abc import Callable +from collections.abc import Callable, Iterator from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch @@ -37,6 +37,10 @@ class NCCLUpdateInfo(BackendUpdateInfo): names: list[str] dtype_names: list[str] shapes: list[list[int]] + packed: bool = False + """Whether to use packed tensor broadcasting for efficiency. + When True, multiple tensors are batched together before broadcasting + to reduce NCCL communication overhead.""" def __post_init__(self): """Validate that all lists have the same length.""" @@ -119,37 +123,107 @@ def receive_weights( """ Receive weights from trainer via NCCL broadcast and load them incrementally. + If update_info.packed is True, uses packed tensor broadcasting for + efficient transfer of multiple weights in batches. Otherwise, uses simple + one-by-one broadcasting. + Args: - update_info: NCCL update info containing parameter names, dtypes, and shapes + update_info: NCCL update info containing parameter names, dtypes, shapes, + and packed flag load_weights: Callable that loads weights into the model. Called - incrementally for each weight to avoid OOM. + incrementally for each batch of weights to avoid OOM. """ if self.model_update_group is None: raise RuntimeError( "NCCL weight transfer not initialized. Call init_transfer() first." 
) - for name, dtype_name, shape in zip( - update_info.names, update_info.dtype_names, update_info.shapes - ): - # Get the torch dtype - dtype = getattr(torch, dtype_name) - - # Allocate buffer for receiving weight - weight = torch.empty(shape, dtype=dtype, device="cuda") - - # Broadcast from rank 0 (trainer) - self.model_update_group.broadcast( - weight, src=0, stream=torch.cuda.current_stream() + if update_info.packed: + # Use packed tensor broadcasting for efficiency + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_consumer, ) - # Load weight immediately to avoid accumulating all weights in memory - load_weights([(name, weight)]) - - # Clean up the weight tensor - del weight + # Build iterator of (name, (shape, dtype)) from update_info + def state_dict_info_iterator(): + for name, dtype_name, shape in zip( + update_info.names, update_info.dtype_names, update_info.shapes + ): + dtype = getattr(torch, dtype_name) + yield (name, (shape, dtype)) + + packed_broadcast_consumer( + iterator=state_dict_info_iterator(), + group=self.model_update_group, + src=0, + post_unpack_func=load_weights, + ) + else: + # Use simple one-by-one broadcasting + for name, dtype_name, shape in zip( + update_info.names, update_info.dtype_names, update_info.shapes + ): + dtype = getattr(torch, dtype_name) + weight = torch.empty(shape, dtype=dtype, device="cuda") + self.model_update_group.broadcast( + weight, src=0, stream=torch.cuda.current_stream() + ) + load_weights([(name, weight)]) + del weight def shutdown(self) -> None: if self.model_update_group is not None: # Clean up the communicator by removing the reference self.model_update_group = None + + @staticmethod + def trainer_broadcast_weights( + iterator: Iterator[tuple[str, torch.Tensor]], + group: Any, + src: int = 0, + post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] + | None = None, + packed: bool = False, + ) -> None: + """Broadcast weights from trainer to vLLM workers. 
+ + Args: + iterator: Iterator of model parameters. Returns (name, tensor) tuples + group: Process group (PyNcclCommunicator) + src: Source rank (default 0, trainer is typically rank 0) + post_iter_func: Optional function to apply to each (name, tensor) pair + before broadcasting. If None, extracts just the tensor. + packed: Whether to use packed tensor broadcasting for efficiency. + When True, multiple tensors are batched together before + broadcasting to reduce NCCL communication overhead. + + Example: + >>> from vllm.distributed.weight_transfer.nccl_engine import ( + ... NCCLWeightTransferEngine, + ... ) + >>> param_iter = ((n, p) for n, p in model.named_parameters()) + >>> NCCLWeightTransferEngine.trainer_broadcast_weights( + ... param_iter, group, packed=True + ... ) + """ + if post_iter_func is None: + # Default: extract just the tensor from (name, tensor) tuple + post_iter_func = lambda x: x[1] + + if packed: + # Use packed tensor broadcasting for efficiency + from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_producer, + ) + + packed_broadcast_producer( + iterator=iterator, + group=group, + src=src, + post_iter_func=post_iter_func, + ) + else: + # Use simple one-by-one broadcasting + for item in iterator: + tensor = post_iter_func(item) + group.broadcast(tensor, src=src, stream=torch.cuda.current_stream()) diff --git a/vllm/distributed/weight_transfer/packed_tensor.py b/vllm/distributed/weight_transfer/packed_tensor.py new file mode 100644 index 000000000000..b1afdc124599 --- /dev/null +++ b/vllm/distributed/weight_transfer/packed_tensor.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Packed tensor utilities for efficient weight transfer.""" + +import math +from collections.abc import Callable, Iterator +from functools import lru_cache +from typing import Any + +import torch + +# Configuration constants (can be overridden via environment 
variables later) +REFIT_BUFFER_MEMORY_RATIO = 0.02 +REFIT_NUM_BUFFERS = 2 +REFIT_MAX_BUFFER_SIZE = 5 * 1024**3 # 5GB max + + +@lru_cache(maxsize=1) +def get_target_packed_tensor_size() -> int: + """Calculate target packed tensor size based on GPU memory.""" + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + total_memory_bytes = props.total_memory + target_size = min( + int(total_memory_bytes * REFIT_BUFFER_MEMORY_RATIO), REFIT_MAX_BUFFER_SIZE + ) + return target_size + + +def packed_broadcast_producer( + iterator: Iterator[tuple[str, torch.Tensor]], + group: Any, + src: int, + post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor], +) -> None: + """Broadcast tensors in a packed manner from trainer to workers. + + Args: + iterator: Iterator of model parameters. Returns a tuple of (name, tensor) + group: Process group (PyNcclCommunicator) + src: Source rank (0 in current implementation) + post_iter_func: Function to apply to each (name, tensor) pair before + packing, should return a tensor + + """ + target_packed_tensor_size = get_target_packed_tensor_size() + num_buffers = REFIT_NUM_BUFFERS + + streams = [torch.cuda.Stream() for _ in range(num_buffers)] + buffer_idx = 0 + + packing_tensor_list: list[list[torch.Tensor]] = [[] for _ in range(num_buffers)] + packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)] + packed_tensors: list[torch.Tensor] = [ + torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers) + ] + + while True: + # Move to the next buffer + buffer_idx = (buffer_idx + 1) % num_buffers + # Synchronize the current stream + streams[buffer_idx].synchronize() + # Start tasks for the new buffer in a new stream + with torch.cuda.stream(streams[buffer_idx]): + try: + # Initialize the packing tensor list and sizes + packing_tensor_list[buffer_idx] = [] + packing_tensor_sizes[buffer_idx] = 0 + # Pack the tensors + while True: + # Apply post processing and convert to linearized uint8 
tensor + tensor = post_iter_func(next(iterator)).view(torch.uint8).view(-1) + packing_tensor_list[buffer_idx].append(tensor) + packing_tensor_sizes[buffer_idx] += tensor.numel() + if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size: + break + # Pack the tensors and call broadcast collective + packed_tensors[buffer_idx] = torch.cat( + packing_tensor_list[buffer_idx], dim=0 + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + except StopIteration: + # Do the last broadcast if there are remaining tensors + if len(packing_tensor_list[buffer_idx]) > 0: + packed_tensors[buffer_idx] = torch.cat( + packing_tensor_list[buffer_idx], dim=0 + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + break + + +def packed_broadcast_consumer( + iterator: Iterator[tuple[str, tuple[list[int], torch.dtype]]], + group: Any, + src: int, + post_unpack_func: Callable[[list[tuple[str, torch.Tensor]]], None], +) -> None: + """Consume packed tensors and unpack them into a list of tensors. + + Args: + iterator: Iterator of parameter metadata. Returns (name, (shape, dtype)) + group: Process group (PyNcclCommunicator) + src: Source rank (0 in current implementation) + post_unpack_func: Function to apply to each list of (name, tensor) after + unpacking + + """ + + def unpack_tensor( + packed_tensor: torch.Tensor, + meta_data_list: list[tuple[str, list[int], torch.dtype, int, int]], + ) -> list[tuple[str, torch.Tensor]]: + """Unpack a single tensor into a list of tensors. 
+ + Args: + packed_tensor: The packed torch.uint8 tensor to unpack + meta_data_list: List[(name, shape, dtype, offset, tensor_size)] + + Returns: + unpacked List[(name, tensor)] + """ + # Perform batched split with torch.split_with_sizes + packed_tensor_sizes = [meta[4] for meta in meta_data_list] + unpacked_tensors = packed_tensor.split_with_sizes(packed_tensor_sizes) + + unpacked_list = [ + ( + meta_data_list[i][0], + tensor.view(meta_data_list[i][2]).view(*meta_data_list[i][1]), + ) + for i, tensor in enumerate(unpacked_tensors) + ] + + return unpacked_list + + target_packed_tensor_size = get_target_packed_tensor_size() + num_buffers = REFIT_NUM_BUFFERS + + streams = [torch.cuda.Stream() for _ in range(num_buffers)] + buffer_idx = 0 + + packing_tensor_meta_data: list[ + list[tuple[str, list[int], torch.dtype, int, int]] + ] = [[] for _ in range(num_buffers)] + packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)] + offsets: list[int] = [0 for _ in range(num_buffers)] + packed_tensors: list[torch.Tensor] = [ + torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers) + ] + + while True: + # Move to the next buffer + buffer_idx = (buffer_idx + 1) % num_buffers + # Synchronize the current stream + streams[buffer_idx].synchronize() + with torch.cuda.stream(streams[buffer_idx]): + # Initialize the packing tensor meta data + packing_tensor_meta_data[buffer_idx] = [] + packing_tensor_sizes[buffer_idx] = 0 + offsets[buffer_idx] = 0 + try: + # Form a packed tensor + while True: + name, (shape, dtype) = next(iterator) + tensor_size = math.prod(shape) * dtype.itemsize + packing_tensor_meta_data[buffer_idx].append( + (name, shape, dtype, offsets[buffer_idx], tensor_size) + ) + packing_tensor_sizes[buffer_idx] += tensor_size + offsets[buffer_idx] += tensor_size + if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size: + break + # Create a packed tensor and broadcast it + packed_tensors[buffer_idx] = torch.empty( + 
packing_tensor_sizes[buffer_idx], dtype=torch.uint8, device="cuda" + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + # Load the packed tensor into the model + post_unpack_func( + unpack_tensor( + packed_tensors[buffer_idx], + packing_tensor_meta_data[buffer_idx], + ) + ) + except StopIteration: + # Do the last broadcast if there are remaining tensors + if len(packing_tensor_meta_data[buffer_idx]) > 0: + # Create a packed tensor and broadcast it + packed_tensors[buffer_idx] = torch.empty( + packing_tensor_sizes[buffer_idx], + dtype=torch.uint8, + device="cuda", + ) + group.broadcast(packed_tensors[buffer_idx], src=src) + # Load the packed tensor into the model + post_unpack_func( + unpack_tensor( + packed_tensors[buffer_idx], + packing_tensor_meta_data[buffer_idx], + ) + ) + break diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5383b0965545..c9b792005ac8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -595,10 +595,9 @@ def __post_init__(self): self.attention_config = AttentionConfig(**self.attention_config) if isinstance(self.eplb_config, dict): self.eplb_config = EPLBConfig(**self.eplb_config) - # Handle weight_transfer_backend CLI arg - if self.weight_transfer_backend is not None: + if isinstance(self.weight_transfer_config, dict): self.weight_transfer_config = WeightTransferConfig( - backend=self.weight_transfer_backend + **self.weight_transfer_config ) # Setup plugins from vllm.plugins import load_general_plugins From dae69463063359a2f77e7904593d4d8d886abb6d Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 12 Jan 2026 18:09:51 -0800 Subject: [PATCH 09/36] added unit tests to CI Signed-off-by: ahao-anyscale --- .buildkite/test-pipeline.yaml | 2 ++ .../test_weight_transfer_llm.py | 34 ++++--------------- 2 files changed, 9 insertions(+), 27 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 44165b9ba52f..921e49e95df0 100644 --- 
a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -1094,6 +1094,8 @@ steps: - pytest -v -s distributed/test_shm_broadcast.py - pytest -v -s distributed/test_shm_buffer.py - pytest -v -s distributed/test_shm_storage.py + - pytest -v -s distributed/test_packed_tensor.py + - pytest -v -s distributed/test_weight_transfer.py - label: 2 Node Tests (4 GPUs in total) # 16min timeout_in_minutes: 30 diff --git a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py index 3cb0d7d3ff44..9f1d655c5dfe 100644 --- a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py +++ b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py @@ -117,25 +117,6 @@ def test_get_world_size_tp1(): assert world_size == 1 -@create_new_process_for_each_test() -def test_get_world_size_tp2(): - """Test world_size is correctly configured for TP=2.""" - if torch.cuda.device_count() < 2: - pytest.skip("Need at least 2 GPUs for this test") - - llm = LLM( - model=MODEL_NAME, - enforce_eager=True, - load_format="dummy", - tensor_parallel_size=2, - distributed_executor_backend="mp", - weight_transfer_config=WeightTransferConfig(backend="nccl"), - ) - - world_size = llm.llm_engine.vllm_config.parallel_config.world_size - assert world_size == 2 - - @create_new_process_for_each_test() def test_init_weight_transfer_calls_engine(): """Test that init_weight_transfer calls the engine's init_transfer method.""" @@ -325,12 +306,11 @@ def check_flow(self): assert result["update_names"] == ["test.weight"] -@pytest.mark.parametrize("tp_size", [1, 2]) @create_new_process_for_each_test() -def test_weight_transfer_with_tp(tp_size): +def test_weight_transfer_with_tp(): """Test weight transfer works correctly with tensor parallelism.""" - if torch.cuda.device_count() < tp_size: - pytest.skip(f"Need at least {tp_size} GPUs for this test") + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this 
test") # Enable insecure serialization to allow pickling functions for collective_rpc os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" @@ -343,7 +323,7 @@ def test_weight_transfer_with_tp(tp_size): model=MODEL_NAME, enforce_eager=True, load_format="dummy", - tensor_parallel_size=tp_size, + tensor_parallel_size=1, distributed_executor_backend="ray", weight_transfer_config=WeightTransferConfig(backend="nccl"), ) @@ -365,14 +345,14 @@ def test_weight_transfer_with_tp(tp_size): llm.finalize_weight_update() - # Verify all TP ranks processed the weight transfer + # Verify the worker processed the weight transfer def check_processed(self): engine = self.weight_transfer_engine return engine.init_transfer_called and engine.receive_weights_called results = llm.collective_rpc(check_processed) - assert len(results) == tp_size - assert all(results), f"All {tp_size} workers should process weight transfer" + assert len(results) == 1 + assert all(results), "Worker should process weight transfer" @create_new_process_for_each_test() From 784921105b6f5be21013e14f5f7bc202c3cc9f62 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Tue, 13 Jan 2026 11:42:35 -0800 Subject: [PATCH 10/36] added env variables, fixes Signed-off-by: ahao-anyscale --- .gitignore | 4 +-- .../weight_transfer/packed_tensor.py | 26 ++++--------------- vllm/engine/arg_utils.py | 4 +-- vllm/entrypoints/llm.py | 4 +-- vllm/envs.py | 10 +++++++ vllm/v1/engine/async_llm.py | 4 +-- 6 files changed, 23 insertions(+), 29 deletions(-) diff --git a/.gitignore b/.gitignore index 9d09506ebd84..7cda86478664 100644 --- a/.gitignore +++ b/.gitignore @@ -98,7 +98,7 @@ ipython_config.py **/generated/** # uv -# uv.lock +uv.lock # pyenv # For a library or package, you might want to ignore these files since the code is @@ -138,7 +138,7 @@ celerybeat.pid *.sage.py # Environments -# .env +.env .venv env/ venv/ diff --git a/vllm/distributed/weight_transfer/packed_tensor.py b/vllm/distributed/weight_transfer/packed_tensor.py index 
b1afdc124599..7366a939bf91 100644 --- a/vllm/distributed/weight_transfer/packed_tensor.py +++ b/vllm/distributed/weight_transfer/packed_tensor.py @@ -4,27 +4,11 @@ import math from collections.abc import Callable, Iterator -from functools import lru_cache from typing import Any import torch -# Configuration constants (can be overridden via environment variables later) -REFIT_BUFFER_MEMORY_RATIO = 0.02 -REFIT_NUM_BUFFERS = 2 -REFIT_MAX_BUFFER_SIZE = 5 * 1024**3 # 5GB max - - -@lru_cache(maxsize=1) -def get_target_packed_tensor_size() -> int: - """Calculate target packed tensor size based on GPU memory.""" - device = torch.device("cuda") - props = torch.cuda.get_device_properties(device) - total_memory_bytes = props.total_memory - target_size = min( - int(total_memory_bytes * REFIT_BUFFER_MEMORY_RATIO), REFIT_MAX_BUFFER_SIZE - ) - return target_size +from vllm import envs def packed_broadcast_producer( @@ -43,8 +27,8 @@ def packed_broadcast_producer( packing, should return a tensor """ - target_packed_tensor_size = get_target_packed_tensor_size() - num_buffers = REFIT_NUM_BUFFERS + target_packed_tensor_size = envs.VLLM_PACKED_TENSOR_BUFFER_SIZE + num_buffers = envs.VLLM_PACKED_TENSOR_NUM_BUFFERS streams = [torch.cuda.Stream() for _ in range(num_buffers)] buffer_idx = 0 @@ -133,8 +117,8 @@ def unpack_tensor( return unpacked_list - target_packed_tensor_size = get_target_packed_tensor_size() - num_buffers = REFIT_NUM_BUFFERS + target_packed_tensor_size = envs.VLLM_PACKED_TENSOR_BUFFER_SIZE + num_buffers = envs.VLLM_PACKED_TENSOR_NUM_BUFFERS streams = [torch.cuda.Stream() for _ in range(num_buffers)] buffer_idx = 0 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c9b792005ac8..6f1554ab14c6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -595,9 +595,9 @@ def __post_init__(self): self.attention_config = AttentionConfig(**self.attention_config) if isinstance(self.eplb_config, dict): self.eplb_config = 
EPLBConfig(**self.eplb_config) - if isinstance(self.weight_transfer_config, dict): + if self.weight_transfer_backend is not None: self.weight_transfer_config = WeightTransferConfig( - **self.weight_transfer_config + backend=self.weight_transfer_backend ) # Setup plugins from vllm.plugins import load_general_plugins diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 6d6d5b405884..af52e420dfa7 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1825,10 +1825,10 @@ def update_weights(self, request: WeightUpdateRequest) -> None: request: Weight update request with backend-specific update info """ - if hasattr(request, "update_info"): + if isinstance(request, WeightUpdateRequest): update_info_dict = request.update_info else: - raise TypeError(f"Invalid `WeightUpdateRequest` format: {type(request)}") + raise TypeError(f"Expected WeightUpdateRequest, got {type(request)}") self.llm_engine.collective_rpc( "update_weights", kwargs={"update_info": update_info_dict} diff --git a/vllm/envs.py b/vllm/envs.py index dadb8c8a231c..468f3db2a836 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -246,6 +246,8 @@ VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False VLLM_DEBUG_MFU_METRICS: bool = False + VLLM_PACKED_TENSOR_NUM_BUFFERS: int = 2 + VLLM_PACKED_TENSOR_BUFFER_SIZE: int = 1024 * 1024 * 1024 # 1GB def get_default_cache_root(): @@ -1575,6 +1577,14 @@ def get_vllm_port() -> int | None: "VLLM_DEBUG_MFU_METRICS": lambda: bool( int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0")) ), + # Number of buffers for packed tensor weight transfer in NCCLWeightTransferEngine + "VLLM_PACKED_TENSOR_NUM_BUFFERS": lambda: int( + os.getenv("VLLM_PACKED_TENSOR_NUM_BUFFERS", "2") + ), + # Size in bytes for each packed tensor buffer (default 1GB) + "VLLM_PACKED_TENSOR_BUFFER_SIZE": lambda: int( + os.getenv("VLLM_PACKED_TENSOR_BUFFER_SIZE", str(1024 * 1024 * 1024)) + ), } # --8<-- [end:env-vars-definition] diff 
--git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b9eb4b997c44..e9553da13af1 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -902,10 +902,10 @@ async def update_weights(self, request: WeightUpdateRequest) -> None: request: Weight update request with backend-specific update info """ - if hasattr(request, "update_info"): + if isinstance(request, WeightUpdateRequest): update_info_dict = request.update_info else: - raise TypeError(f"Invalid WeightUpdateRequest format: {type(request)}") + raise TypeError(f"Expected WeightUpdateRequest, got {type(request)}") await self.collective_rpc( "update_weights", kwargs={"update_info": update_info_dict} From 437b14f3eb7c3707203ab3c498fa84f20cef6263 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Wed, 14 Jan 2026 15:28:23 -0800 Subject: [PATCH 11/36] precommit fix Signed-off-by: ahao-anyscale --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f79ae6f59cca..bf08bd3dd26f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1846,7 +1846,7 @@ def finalize_weight_update(self) -> None: This should be called after all weight updates are complete. 
""" self.llm_engine.collective_rpc("finalize_weight_update") - + def __repr__(self) -> str: """Return a transformers-style hierarchical view of the model.""" # Cache the result to avoid repeated collective_rpc calls From fc43e5b04d3857034d7a3857922bbdf2c8bd7c86 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 15 Jan 2026 13:34:18 -0800 Subject: [PATCH 12/36] test fixes Signed-off-by: ahao-anyscale --- tests/distributed/test_packed_tensor.py | 310 +++++++----------- .../entrypoints/openai/test_openai_schema.py | 9 + .../test_weight_transfer_llm.py | 49 --- vllm/entrypoints/openai/api_server.py | 53 --- vllm/entrypoints/serve/rlhf/api_router.py | 57 +++- 5 files changed, 186 insertions(+), 292 deletions(-) diff --git a/tests/distributed/test_packed_tensor.py b/tests/distributed/test_packed_tensor.py index fc6ff91e4920..02a787e04c8c 100644 --- a/tests/distributed/test_packed_tensor.py +++ b/tests/distributed/test_packed_tensor.py @@ -6,11 +6,10 @@ These utilities enable efficient batched tensor transfer over NCCL. 
""" -from unittest.mock import patch - import pytest import torch +from vllm import envs from vllm.distributed.weight_transfer.nccl_engine import NCCLUpdateInfo @@ -95,7 +94,7 @@ def test_packed_can_be_set_true(self): class TestPackedBroadcastProducer: """Test packed_broadcast_producer function.""" - def test_producer_broadcasts_tensors(self): + def test_producer_broadcasts_tensors(self, monkeypatch): """Test that producer broadcasts all tensors.""" from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_producer, @@ -107,22 +106,19 @@ def test_producer_broadcasts_tensors(self): mock_group = MockCommunicationGroup() # Use a small target size to force multiple batches - with patch( - "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", - return_value=500, - ): - packed_broadcast_producer( - iterator=iter(params_cuda), - group=mock_group, - src=0, - post_iter_func=lambda x: x[1], - ) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 500) + packed_broadcast_producer( + iterator=iter(params_cuda), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + ) # Should have broadcasted some tensors assert mock_group.broadcast_count > 0 assert len(mock_group.broadcasted_tensors) > 0 - def test_producer_single_large_tensor(self): + def test_producer_single_large_tensor(self, monkeypatch): """Test with a single tensor larger than target size.""" from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_producer, @@ -135,16 +131,13 @@ def test_producer_single_large_tensor(self): mock_group = MockCommunicationGroup() # Small target size to force the tensor to exceed it - with patch( - "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", - return_value=100, - ): - packed_broadcast_producer( - iterator=iter(params), - group=mock_group, - src=0, - post_iter_func=lambda x: x[1], - ) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 100) + 
packed_broadcast_producer( + iterator=iter(params), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + ) # Should still broadcast the tensor (at least 1 broadcast) assert mock_group.broadcast_count >= 1 @@ -155,7 +148,7 @@ def test_producer_single_large_tensor(self): actual_size = sum(t.numel() for t in mock_group.broadcasted_tensors) assert actual_size == expected_size - def test_producer_multiple_batches(self): + def test_producer_multiple_batches(self, monkeypatch): """Test that tensors are properly batched when exceeding target size.""" from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_producer, @@ -170,16 +163,13 @@ def test_producer_multiple_batches(self): mock_group = MockCommunicationGroup() # Small target size to force multiple batches - with patch( - "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", - return_value=2000, - ): - packed_broadcast_producer( - iterator=iter(params), - group=mock_group, - src=0, - post_iter_func=lambda x: x[1], - ) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 2000) + packed_broadcast_producer( + iterator=iter(params), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + ) # Should have multiple broadcasts assert mock_group.broadcast_count > 1 @@ -189,7 +179,7 @@ def test_producer_multiple_batches(self): actual_total = sum(t.numel() for t in mock_group.broadcasted_tensors) assert actual_total == expected_total - def test_producer_empty_iterator(self): + def test_producer_empty_iterator(self, monkeypatch): """Test producer handles empty iterator gracefully.""" from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_producer, @@ -197,16 +187,13 @@ def test_producer_empty_iterator(self): mock_group = MockCommunicationGroup() - with patch( - "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", - return_value=1000, - ): - packed_broadcast_producer( - iterator=iter([]), - group=mock_group, 
- src=0, - post_iter_func=lambda x: x[1], - ) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 1000) + packed_broadcast_producer( + iterator=iter([]), + group=mock_group, + src=0, + post_iter_func=lambda x: x[1], + ) # No broadcasts for empty iterator assert mock_group.broadcast_count == 0 @@ -219,7 +206,7 @@ def test_producer_empty_iterator(self): class TestPackedBroadcastConsumer: """Test packed_broadcast_consumer function.""" - def test_consumer_receives_tensors(self): + def test_consumer_receives_tensors(self, monkeypatch): """Test that consumer receives and unpacks tensors.""" from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_consumer, @@ -232,36 +219,33 @@ def test_consumer_receives_tensors(self): # First, run producer to get the broadcasted tensors producer_group = MockCommunicationGroup() - with patch( - "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", - return_value=2000, - ): - packed_broadcast_producer( - iterator=iter(params_cuda), - group=producer_group, - src=0, - post_iter_func=lambda x: x[1], - ) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 2000) + packed_broadcast_producer( + iterator=iter(params_cuda), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + ) - # Now run consumer with the broadcasted tensors - consumer_group = MockConsumerCommunicationGroup( - producer_group.broadcasted_tensors - ) + # Now run consumer with the broadcasted tensors + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) - state_dict_info = create_state_dict_info(params_cuda) + state_dict_info = create_state_dict_info(params_cuda) - unpacked_tensors = {} + unpacked_tensors = {} - def post_unpack_func(tensor_list): - for name, tensor in tensor_list: - unpacked_tensors[name] = tensor.clone() + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() - packed_broadcast_consumer( - 
iterator=iter(state_dict_info.items()), - group=consumer_group, - src=0, - post_unpack_func=post_unpack_func, - ) + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + ) # Verify all parameters were unpacked assert len(unpacked_tensors) == len(params) @@ -283,7 +267,7 @@ class TestPackedBroadcastRoundtrip: """Test producer-consumer roundtrip behavior.""" @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) - def test_roundtrip_different_dtypes(self, dtype): + def test_roundtrip_different_dtypes(self, dtype, monkeypatch): """Test roundtrip with different data types.""" from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_consumer, @@ -295,34 +279,31 @@ def test_roundtrip_different_dtypes(self, dtype): producer_group = MockCommunicationGroup() - with patch( - "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", - return_value=1000, - ): - packed_broadcast_producer( - iterator=iter(params_cuda), - group=producer_group, - src=0, - post_iter_func=lambda x: x[1], - ) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 1000) + packed_broadcast_producer( + iterator=iter(params_cuda), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + ) - consumer_group = MockConsumerCommunicationGroup( - producer_group.broadcasted_tensors - ) + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) - state_dict_info = create_state_dict_info(params_cuda) - unpacked_tensors = {} + state_dict_info = create_state_dict_info(params_cuda) + unpacked_tensors = {} - def post_unpack_func(tensor_list): - for name, tensor in tensor_list: - unpacked_tensors[name] = tensor.clone() + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() - packed_broadcast_consumer( - iterator=iter(state_dict_info.items()), - 
group=consumer_group, - src=0, - post_unpack_func=post_unpack_func, - ) + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + ) # Verify roundtrip preserves data for name, original_tensor in params_cuda: @@ -331,7 +312,7 @@ def post_unpack_func(tensor_list): assert unpacked.dtype == dtype assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6) - def test_roundtrip_mixed_dtypes(self): + def test_roundtrip_mixed_dtypes(self, monkeypatch): """Test roundtrip with mixed data types.""" from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_consumer, @@ -347,34 +328,31 @@ def test_roundtrip_mixed_dtypes(self): producer_group = MockCommunicationGroup() - with patch( - "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", - return_value=500, - ): - packed_broadcast_producer( - iterator=iter(params), - group=producer_group, - src=0, - post_iter_func=lambda x: x[1], - ) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 500) + packed_broadcast_producer( + iterator=iter(params), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + ) - consumer_group = MockConsumerCommunicationGroup( - producer_group.broadcasted_tensors - ) + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) - state_dict_info = create_state_dict_info(params) - unpacked_tensors = {} + state_dict_info = create_state_dict_info(params) + unpacked_tensors = {} - def post_unpack_func(tensor_list): - for name, tensor in tensor_list: - unpacked_tensors[name] = tensor.clone() + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() - packed_broadcast_consumer( - iterator=iter(state_dict_info.items()), - group=consumer_group, - src=0, - post_unpack_func=post_unpack_func, - ) + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + 
group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + ) # Verify all params roundtrip correctly with correct dtypes for name, original_tensor in params: @@ -385,7 +363,7 @@ def post_unpack_func(tensor_list): assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6) @pytest.mark.parametrize("target_size", [100, 1000, 10000, 100000]) - def test_roundtrip_different_batch_sizes(self, target_size): + def test_roundtrip_different_batch_sizes(self, target_size, monkeypatch): """Test roundtrip with different target batch sizes.""" from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_consumer, @@ -397,34 +375,31 @@ def test_roundtrip_different_batch_sizes(self, target_size): producer_group = MockCommunicationGroup() - with patch( - "vllm.distributed.weight_transfer.packed_tensor.get_target_packed_tensor_size", - return_value=target_size, - ): - packed_broadcast_producer( - iterator=iter(params_cuda), - group=producer_group, - src=0, - post_iter_func=lambda x: x[1], - ) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", target_size) + packed_broadcast_producer( + iterator=iter(params_cuda), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + ) - consumer_group = MockConsumerCommunicationGroup( - producer_group.broadcasted_tensors - ) + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) - state_dict_info = create_state_dict_info(params_cuda) - unpacked_tensors = {} + state_dict_info = create_state_dict_info(params_cuda) + unpacked_tensors = {} - def post_unpack_func(tensor_list): - for name, tensor in tensor_list: - unpacked_tensors[name] = tensor.clone() + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() - packed_broadcast_consumer( - iterator=iter(state_dict_info.items()), - group=consumer_group, - src=0, - post_unpack_func=post_unpack_func, - ) + packed_broadcast_consumer( + 
iterator=iter(state_dict_info.items()), + group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + ) # Verify all params roundtrip correctly assert len(unpacked_tensors) == len(params) @@ -433,44 +408,3 @@ def post_unpack_func(tensor_list): assert torch.allclose( unpacked_tensors[name], original_tensor, rtol=1e-5, atol=1e-7 ) - - -# --- Unit Tests: get_target_packed_tensor_size --- - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -class TestGetTargetPackedTensorSize: - """Test get_target_packed_tensor_size function.""" - - def test_returns_positive_value(self): - """Test that function returns a positive value.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - get_target_packed_tensor_size, - ) - - # Clear cache to get fresh value - get_target_packed_tensor_size.cache_clear() - size = get_target_packed_tensor_size() - assert size > 0 - - def test_is_cached(self): - """Test that function result is cached.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - get_target_packed_tensor_size, - ) - - get_target_packed_tensor_size.cache_clear() - size1 = get_target_packed_tensor_size() - size2 = get_target_packed_tensor_size() - assert size1 == size2 - - def test_respects_max_buffer_size(self): - """Test that function respects max buffer size.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - REFIT_MAX_BUFFER_SIZE, - get_target_packed_tensor_size, - ) - - get_target_packed_tensor_size.cache_clear() - size = get_target_packed_tensor_size() - assert size <= REFIT_MAX_BUFFER_SIZE diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 50d24a400549..9e7758329904 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -139,6 +139,15 @@ def test_openapi_stateless(case: schemathesis.Case): # Skip responses API as it is meant to be stateful. 
return + # Skip weight transfer endpoints as they require special setup + # (weight_transfer_config) and are meant to be stateful. + if case.operation.path in ( + "/init_weight_transfer", + "/update_weights", + "/finalize_weight_update", + ): + return + timeout = { # requires a longer timeout ("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS, diff --git a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py index 9f1d655c5dfe..22505bbdea51 100644 --- a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py +++ b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py @@ -306,55 +306,6 @@ def check_flow(self): assert result["update_names"] == ["test.weight"] -@create_new_process_for_each_test() -def test_weight_transfer_with_tp(): - """Test weight transfer works correctly with tensor parallelism.""" - if torch.cuda.device_count() < 1: - pytest.skip("Need at least 1 GPU for this test") - - # Enable insecure serialization to allow pickling functions for collective_rpc - os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" - - with patch( - "vllm.distributed.weight_transfer.init_transfer_engine", - mock_init_transfer_engine, - ): - llm = LLM( - model=MODEL_NAME, - enforce_eager=True, - load_format="dummy", - tensor_parallel_size=1, - distributed_executor_backend="ray", - weight_transfer_config=WeightTransferConfig(backend="nccl"), - ) - - # Run weight transfer - llm.init_weight_transfer( - WeightTransferInitRequest(init_info={"test_param": "tp_test"}) - ) - - llm.update_weights( - WeightUpdateRequest( - update_info={ - "names": ["w"], - "dtype_names": ["float16"], - "shapes": [[50]], - } - ) - ) - - llm.finalize_weight_update() - - # Verify the worker processed the weight transfer - def check_processed(self): - engine = self.weight_transfer_engine - return engine.init_transfer_called and engine.receive_weights_called - - results = llm.collective_rpc(check_processed) - assert 
len(results) == 1 - assert all(results), "Worker should process weight transfer" - - @create_new_process_for_each_test() def test_weight_transfer_config_backend(): """Test that WeightTransferConfig backend is properly configured.""" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3b74d249b33d..39d07f6d2547 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -31,8 +31,6 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send import vllm.envs as envs -from vllm.distributed.weight_transfer import WeightUpdateRequest -from vllm.distributed.weight_transfer.base import WeightTransferInitRequest from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient from vllm.entrypoints.anthropic.protocol import ( @@ -390,57 +388,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/init_weight_transfer") -async def init_weight_transfer(raw_request: Request): - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 - init_info = body.get("init_info") - if init_info is None: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail="Missing 'init_info' in request body", - ) - await engine_client(raw_request).init_weight_transfer( - WeightTransferInitRequest(init_info=init_info) - ) - return JSONResponse(content={"message": "Weight transfer initialized"}) - - -@router.post("/update_weights") -async def update_weights(raw_request: Request): - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 - update_info = body.get("update_info") - if update_info is None: - raise HTTPException( - 
status_code=HTTPStatus.BAD_REQUEST.value, - detail="Missing 'update_info' in request body", - ) - await engine_client(raw_request).update_weights( - request=WeightUpdateRequest(update_info=update_info) - ) - return JSONResponse(content={"message": "Weights updated"}) - - -@router.post("/finalize_weight_update") -async def finalize_weight_update(raw_request: Request): - await engine_client(raw_request).finalize_weight_update() - return JSONResponse(content={"message": "Weight update finalized"}) - - -@router.get("/get_world_size") -async def get_world_size(raw_request: Request): - """Get the world size from the parallel config (TP * PP * DP).""" - world_size = engine_client( - raw_request - ).vllm_config.parallel_config.world_size_across_dp - return JSONResponse(content={"world_size": world_size}) - - def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py index 3b37840ae089..04b17e9c3b34 100644 --- a/vllm/entrypoints/serve/rlhf/api_router.py +++ b/vllm/entrypoints/serve/rlhf/api_router.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import json from http import HTTPStatus -from fastapi import APIRouter, FastAPI, Query, Request +from fastapi import APIRouter, FastAPI, HTTPException, Query, Request from fastapi.responses import JSONResponse +from vllm.distributed.weight_transfer import WeightUpdateRequest +from vllm.distributed.weight_transfer.base import WeightTransferInitRequest from vllm.engine.protocol import EngineClient from vllm.logger import init_logger @@ -98,5 +100,56 @@ async def is_paused(raw_request: Request) -> JSONResponse: return JSONResponse(content={"is_paused": paused}) +@router.post("/init_weight_transfer") +async def init_weight_transfer(raw_request: Request): + try: + body = await raw_request.json() + except 
json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 + init_info = body.get("init_info") + if init_info is None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'init_info' in request body", + ) + await engine_client(raw_request).init_weight_transfer( + WeightTransferInitRequest(init_info=init_info) + ) + return JSONResponse(content={"message": "Weight transfer initialized"}) + + +@router.post("/update_weights") +async def update_weights(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 + update_info = body.get("update_info") + if update_info is None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'update_info' in request body", + ) + await engine_client(raw_request).update_weights( + request=WeightUpdateRequest(update_info=update_info) + ) + return JSONResponse(content={"message": "Weights updated"}) + + +@router.post("/finalize_weight_update") +async def finalize_weight_update(raw_request: Request): + await engine_client(raw_request).finalize_weight_update() + return JSONResponse(content={"message": "Weight update finalized"}) + + +@router.get("/get_world_size") +async def get_world_size(raw_request: Request): + """Get the world size from the parallel config (TP * PP * DP).""" + world_size = engine_client( + raw_request + ).vllm_config.parallel_config.world_size_across_dp + return JSONResponse(content={"world_size": world_size}) + + def attach_router(app: FastAPI): app.include_router(router) From e2dc668c96092fdc758906df855b835e5504fb14 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 15 Jan 2026 16:30:32 -0800 Subject: [PATCH 13/36] fix examples Signed-off-by: ahao-anyscale --- .buildkite/test-amd.yaml | 8 ++++---- .buildkite/test-pipeline.yaml | 8 ++++---- 
.buildkite/test_areas/distributed.yaml | 8 ++++---- docs/training/rlhf.md | 5 ++--- examples/offline_inference/weight_syncing/rlhf.py | 3 +-- .../weight_syncing/rlhf_async_new_apis.py | 1 - examples/offline_inference/weight_syncing/rlhf_ipc.py | 3 +-- 7 files changed, 16 insertions(+), 20 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 044a82c9773f..d2f33be0477b 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -227,8 +227,7 @@ steps: - tests/distributed/test_pynccl - tests/distributed/test_events - tests/compile/fullgraph/test_basic_correctness.py - - examples/offline_inference/rlhf.py - - examples/offline_inference/rlhf_colocate.py + - examples/offline_inference/weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -264,9 +263,10 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - - pushd ../examples/offline_inference + - pushd ../examples/offline_inference/weight_syncing - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - popd - label: Distributed Tests (8 GPUs) # 4min diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d88eedd7a034..592c71cd2bc1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -200,8 +200,7 @@ steps: - tests/distributed/test_pynccl - tests/distributed/test_events - tests/compile/fullgraph/test_basic_correctness.py - - examples/offline_inference/rlhf.py - - examples/offline_inference/rlhf_colocate.py + - examples/offline_inference/weight_syncing/ - tests/examples/offline_inference/data_parallel.py - 
tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -236,9 +235,10 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - - pushd ../examples/offline_inference + - pushd ../examples/offline_inference/weight_syncing - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - popd - label: Distributed Tests (8 GPUs) # 4min diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index c88076bb528e..330142452725 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -61,8 +61,7 @@ steps: - tests/distributed/test_pynccl - tests/distributed/test_events - tests/compile/fullgraph/test_basic_correctness.py - - examples/offline_inference/rlhf.py - - examples/offline_inference/rlhf_colocate.py + - examples/offline_inference/weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -97,9 +96,10 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - - cd ../examples/offline_inference + - cd ../examples/offline_inference/weight_syncing - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - label: Distributed Tests (8 GPUs)(H100) timeout_in_minutes: 10 diff --git a/docs/training/rlhf.md b/docs/training/rlhf.md index 
0b7e384dc8d6..e6ca23e4fe62 100644 --- a/docs/training/rlhf.md +++ b/docs/training/rlhf.md @@ -18,9 +18,8 @@ The following open-source RL libraries use vLLM for fast rollouts (sorted alphab See the following basic examples to get started if you don't want to use an existing library: -- [Training and inference processes are located on separate GPUs (inspired by OpenRLHF)](../examples/offline_inference/rlhf.md) -- [Training and inference processes are colocated on the same GPUs using Ray](../examples/offline_inference/rlhf_colocate.md) -- [Utilities for performing RLHF with vLLM](../examples/offline_inference/rlhf_utils.md) +- [Training and inference processes are located on separate GPUs](../examples/offline_inference/weight_syncing/rlhf.md) +- [Training and inference processes communicate via IPC](../examples/offline_inference/weight_syncing/rlhf_ipc.md) See the following notebooks showing how to use vLLM for GRPO: diff --git a/examples/offline_inference/weight_syncing/rlhf.py b/examples/offline_inference/weight_syncing/rlhf.py index 473e6a9a840d..e772ce0f6f05 100644 --- a/examples/offline_inference/weight_syncing/rlhf.py +++ b/examples/offline_inference/weight_syncing/rlhf.py @@ -48,8 +48,7 @@ ) from vllm.utils.network_utils import get_ip, get_open_port -MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507" -# MODEL_NAME = "facebook/opt-125m" +MODEL_NAME = "facebook/opt-125m" class MyLLM(LLM): diff --git a/examples/offline_inference/weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/weight_syncing/rlhf_async_new_apis.py index 819c592a8caa..45cae4a25bf1 100644 --- a/examples/offline_inference/weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/weight_syncing/rlhf_async_new_apis.py @@ -47,7 +47,6 @@ ) from vllm.utils.network_utils import get_ip, get_open_port -# MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507" MODEL_NAME = "facebook/opt-125m" diff --git a/examples/offline_inference/weight_syncing/rlhf_ipc.py 
b/examples/offline_inference/weight_syncing/rlhf_ipc.py index 1d7bff638b25..30167de343f0 100644 --- a/examples/offline_inference/weight_syncing/rlhf_ipc.py +++ b/examples/offline_inference/weight_syncing/rlhf_ipc.py @@ -61,8 +61,7 @@ def get_physical_gpu_id(): # Load the OPT-125M model onto GPU 0 for the training workload. -# MODEL_NAME = "facebook/opt-125m" -MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +MODEL_NAME = "facebook/opt-125m" @ray.remote From 50b60392efb79797c9bb441d21514919c45315bb Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Tue, 20 Jan 2026 11:49:53 -0800 Subject: [PATCH 14/36] x Signed-off-by: ahao-anyscale --- .../offline_inference/weight_syncing/rlhf.py | 82 +++++++++++++------ 1 file changed, 55 insertions(+), 27 deletions(-) diff --git a/examples/offline_inference/weight_syncing/rlhf.py b/examples/offline_inference/weight_syncing/rlhf.py index e772ce0f6f05..c0d61ec0e897 100644 --- a/examples/offline_inference/weight_syncing/rlhf.py +++ b/examples/offline_inference/weight_syncing/rlhf.py @@ -61,22 +61,61 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -torch.cuda.set_device("cuda:0") -# Load the OPT-125M model onto GPU 0 for the training workload. 
-train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16) -train_model.to("cuda:0") +@ray.remote(num_gpus=1) +class TrainModel: + """Ray actor that wraps the training model on a dedicated GPU.""" + + def __init__(self, model_name: str): + self.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, dtype=torch.bfloat16 + ) + self.model.to(self.device) + + def zero_weights(self): + """Zero out all model weights (simulates training step).""" + for name, p in self.model.named_parameters(): + p.data.zero_() + + def get_weight_metadata(self): + """Return weight names, dtypes, and shapes for weight transfer.""" + names = [] + dtype_names = [] + shapes = [] + for name, p in self.model.named_parameters(): + names.append(name) + dtype_names.append(str(p.dtype).split(".")[-1]) + shapes.append(list(p.shape)) + return names, dtype_names, shapes + + def init_weight_transfer_group(self, master_address, master_port, world_size): + """Initialize the NCCL process group for weight transfer.""" + self.model_update_group = stateless_init_process_group( + master_address, master_port, 0, world_size, self.device + ) + + def broadcast_weights(self, packed: bool = True): + """Broadcast weights to the inference engine.""" + NCCLWeightTransferEngine.trainer_broadcast_weights( + iterator=self.model.named_parameters(), + group=self.model_update_group, + packed=packed, + ) + # Initialize Ray and set the visible devices. The vLLM engine will # be placed on GPUs 1 and 2. -os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" ray.init(runtime_env={"excludes": [".git/objects/pack/"]}) # ray.init() # Create a placement group that reserves GPU 1–2 for the vLLM inference engine. 
# Learn more about Ray placement groups: # https://docs.ray.io/en/latest/placement-groups.html -pg_training = placement_group([{"GPU": 1, "CPU": 0}]) -ray.get(pg_training.ready()) +# Launch the training model actor. Ray's resource scheduler will allocate +# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. +train_model = TrainModel.remote(MODEL_NAME) pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) ray.get(pg_inference.ready()) @@ -142,26 +181,20 @@ def __init__(self, *args, **kwargs): ) ) -model_update_group = stateless_init_process_group( - master_address, master_port, 0, world_size, torch.device("cuda:0") +# Initialize weight transfer group on both the training actor and inference engine +train_handle = train_model.init_weight_transfer_group.remote( + master_address, master_port, world_size ) -ray.get(handle) +ray.get([train_handle, handle]) # Simulate a training step by zeroing out all model weights. # In a real RLHF training loop the weights would be updated using the gradient # from an RL objective such as PPO on a reward model. -for name, p in train_model.named_parameters(): - p.data.zero_() +ray.get(train_model.zero_weights.remote()) # Synchronize the updated weights to the inference engine using batched API. 
-# Collect all weight metadata -names = [] -dtype_names = [] -shapes = [] -for name, p in train_model.named_parameters(): - names.append(name) - dtype_names.append(str(p.dtype).split(".")[-1]) - shapes.append(p.shape) +# Collect all weight metadata from the training actor +names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote()) # Issue update_weights call with NCCL-specific update info # packed=True enables efficient batched tensor broadcasting @@ -179,13 +212,8 @@ def __init__(self, *args, **kwargs): ) # Broadcast all weights from trainer using the weight transfer API -NCCLWeightTransferEngine.trainer_broadcast_weights( - iterator=train_model.named_parameters(), - group=model_update_group, - packed=True, -) - -ray.get(handle) +broadcast_handle = train_model.broadcast_weights.remote(packed=True) +ray.get([broadcast_handle, handle]) # Finalize the weight update (processes weights for quantization/kernel format) ray.get(llm.finalize_weight_update.remote()) From 8de0daa2713da1846b852d012170507873b57040 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Wed, 21 Jan 2026 07:33:40 +0000 Subject: [PATCH 15/36] x Signed-off-by: ahao-anyscale --- docs/training/rlhf.md | 4 ++-- vllm/engine/arg_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/training/rlhf.md b/docs/training/rlhf.md index e6ca23e4fe62..2d01686b8f5f 100644 --- a/docs/training/rlhf.md +++ b/docs/training/rlhf.md @@ -18,8 +18,8 @@ The following open-source RL libraries use vLLM for fast rollouts (sorted alphab See the following basic examples to get started if you don't want to use an existing library: -- [Training and inference processes are located on separate GPUs](../examples/offline_inference/weight_syncing/rlhf.md) -- [Training and inference processes communicate via IPC](../examples/offline_inference/weight_syncing/rlhf_ipc.md) +- [Training and inference processes are located on separate GPUs](../examples/offline_inference/weight_syncing.md#rlhf) +- 
[Training and inference processes communicate via IPC](../examples/offline_inference/weight_syncing.md#rlhf-ipc) See the following notebooks showing how to use vLLM for GRPO: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 538a58cf59a3..8a180cd5906c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -579,7 +579,7 @@ class EngineArgs: kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend tokens_only: bool = False - weight_transfer_backend: str = "nccl" + weight_transfer_backend: str | None = None """Backend for weight transfer during RL training. Options: nccl, ipc""" weight_transfer_config: WeightTransferConfig = field( From ddb178d342191c654baf0768544a8ae15ac37e93 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Wed, 21 Jan 2026 11:26:28 -0800 Subject: [PATCH 16/36] x Signed-off-by: ahao-anyscale --- .../offline_inference/weight_syncing/rlhf.py | 10 +- .../weight_syncing/rlhf_async_new_apis.py | 121 ++++++++++++------ 2 files changed, 90 insertions(+), 41 deletions(-) diff --git a/examples/offline_inference/weight_syncing/rlhf.py b/examples/offline_inference/weight_syncing/rlhf.py index c0d61ec0e897..b322ba24ddf0 100644 --- a/examples/offline_inference/weight_syncing/rlhf.py +++ b/examples/offline_inference/weight_syncing/rlhf.py @@ -168,7 +168,7 @@ def broadcast_weights(self, packed: bool = True): master_port = get_open_port() world_size = ray.get(llm.get_world_size.remote()) + 1 # +1 for the trainer -handle = llm.init_weight_transfer.remote( +inference_handle = llm.init_weight_transfer.remote( WeightTransferInitRequest( init_info=asdict( NCCLInitInfo( @@ -185,7 +185,7 @@ def broadcast_weights(self, packed: bool = True): train_handle = train_model.init_weight_transfer_group.remote( master_address, master_port, world_size ) -ray.get([train_handle, handle]) +ray.get([train_handle, inference_handle]) # Simulate a training step by zeroing out all model weights. 
# In a real RLHF training loop the weights would be updated using the gradient @@ -198,7 +198,7 @@ def broadcast_weights(self, packed: bool = True): # Issue update_weights call with NCCL-specific update info # packed=True enables efficient batched tensor broadcasting -handle = llm.update_weights.remote( +inference_handle = llm.update_weights.remote( WeightUpdateRequest( update_info=asdict( NCCLUpdateInfo( @@ -212,8 +212,8 @@ def broadcast_weights(self, packed: bool = True): ) # Broadcast all weights from trainer using the weight transfer API -broadcast_handle = train_model.broadcast_weights.remote(packed=True) -ray.get([broadcast_handle, handle]) +train_handle = train_model.broadcast_weights.remote(packed=True) +ray.get([train_handle, inference_handle]) # Finalize the weight update (processes weights for quantization/kernel format) ray.get(llm.finalize_weight_update.remote()) diff --git a/examples/offline_inference/weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/weight_syncing/rlhf_async_new_apis.py index 45cae4a25bf1..90e1ef1457bc 100644 --- a/examples/offline_inference/weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/weight_syncing/rlhf_async_new_apis.py @@ -30,6 +30,7 @@ import asyncio import os import uuid +from dataclasses import asdict import ray import torch @@ -45,11 +46,59 @@ WeightTransferInitRequest, WeightUpdateRequest, ) +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLInitInfo, + NCCLUpdateInfo, + NCCLWeightTransferEngine, +) from vllm.utils.network_utils import get_ip, get_open_port MODEL_NAME = "facebook/opt-125m" +@ray.remote(num_gpus=1) +class TrainModel: + """Ray actor that wraps the training model on a dedicated GPU.""" + + def __init__(self, model_name: str): + self.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, dtype=torch.bfloat16 + ) + self.model.to(self.device) + + def zero_weights(self): + 
"""Zero out all model weights (simulates training step).""" + for name, p in self.model.named_parameters(): + p.data.zero_() + + def get_weight_metadata(self): + """Return weight names, dtypes, and shapes for weight transfer.""" + names = [] + dtype_names = [] + shapes = [] + for name, p in self.model.named_parameters(): + names.append(name) + dtype_names.append(str(p.dtype).split(".")[-1]) + shapes.append(list(p.shape)) + return names, dtype_names, shapes + + def init_weight_transfer_group(self, master_address, master_port, world_size): + """Initialize the NCCL process group for weight transfer.""" + self.model_update_group = stateless_init_process_group( + master_address, master_port, 0, world_size, self.device + ) + + def broadcast_weights(self, packed: bool = True): + """Broadcast weights to the inference engine.""" + NCCLWeightTransferEngine.trainer_broadcast_weights( + iterator=self.model.named_parameters(), + group=self.model_update_group, + packed=packed, + ) + + class MyLLM: """Simple wrapper over AsyncLLM for supporting async RL.""" @@ -109,21 +158,19 @@ async def finalize_weight_update(self) -> None: return await self.engine.finalize_weight_update() -# Load the OPT-125M model onto GPU 0 for the training workload. -train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16) -train_model.to("cuda:0") - # Initialize Ray and set the visible devices. The vLLM engine will # be placed on GPUs 1 and 2. -os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" ray.init(runtime_env={"excludes": [".git/objects/pack/"]}) # ray.init() +# Launch the training model actor. Ray's resource scheduler will allocate +# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. +train_model = TrainModel.remote(MODEL_NAME) + # Create a placement group that reserves GPU 1–2 for the vLLM inference engine. 
# Learn more about Ray placement groups: # https://docs.ray.io/en/latest/placement-groups.html -pg_training = placement_group([{"GPU": 1, "CPU": 0}]) -ray.get(pg_training.ready()) pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) ray.get(pg_inference.ready()) @@ -169,22 +216,25 @@ async def finalize_weight_update(self) -> None: master_address = get_ip() master_port = get_open_port() -print("reached init weight in driver") -handle = llm.init_weight_transfer.remote( +world_size = 3 # 1 trainer + 2 inference workers (tensor_parallel_size=2) +inference_handle = llm.init_weight_transfer.remote( WeightTransferInitRequest( - init_info=dict( - master_address=master_address, - master_port=master_port, - rank_offset=1, - world_size=3, + init_info=asdict( + NCCLInitInfo( + master_address=master_address, + master_port=master_port, + rank_offset=1, + world_size=world_size, + ) ) ) ) -model_update_group = stateless_init_process_group( - master_address, master_port, 0, 3, torch.device("cuda:0") +# Initialize weight transfer group on both the training actor and inference engine +train_handle = train_model.init_weight_transfer_group.remote( + master_address, master_port, world_size ) -ray.get(handle) +ray.get([train_handle, inference_handle]) generation_futures = [ @@ -200,31 +250,30 @@ async def finalize_weight_update(self) -> None: # Simulate a training step by zeroing out all model weights. # In a real RLHF training loop the weights would be updated using the gradient # from an RL objective such as PPO on a reward model. -for name, p in train_model.named_parameters(): - p.data.zero_() +ray.get(train_model.zero_weights.remote()) # Synchronize the updated weights to the inference engine using batched API. 
-# Collect all weight metadata -names = [] -dtype_names = [] -shapes = [] -for name, p in train_model.named_parameters(): - names.append(name) - dtype_names.append(str(p.dtype).split(".")[-1]) - shapes.append(p.shape) - -# Issue update_weights call -handle = llm.update_weights.remote( +# Collect all weight metadata from the training actor +names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote()) + +# Issue update_weights call with NCCL-specific update info +# packed=True enables efficient batched tensor broadcasting +inference_handle = llm.update_weights.remote( WeightUpdateRequest( - update_info=dict(names=names, dtype_names=dtype_names, shapes=shapes) + update_info=asdict( + NCCLUpdateInfo( + names=names, + dtype_names=dtype_names, + shapes=shapes, + packed=True, + ) + ) ) ) -# Broadcast all weights from trainer -for name, p in train_model.named_parameters(): - model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) - -ray.get(handle) +# Broadcast all weights from trainer using the weight transfer API +train_handle = train_model.broadcast_weights.remote(packed=True) +ray.get([train_handle, inference_handle]) # Finalize the weight update (processes weights for quantization/kernel format) ray.get(llm.finalize_weight_update.remote()) From 4d69ed3a3f8da21f9aef9741658a51607d351678 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 26 Jan 2026 11:52:48 -0800 Subject: [PATCH 17/36] removed ipc Signed-off-by: ahao-anyscale --- docs/training/rlhf.md | 5 +- .../rlhf.py | 0 .../rlhf_async_new_apis.py | 0 .../rlhf_http.py | 0 .../rlhf_utils.py | 0 .../{weight_syncing/legacy => }/rlhf.py | 0 .../legacy => }/rlhf_colocate.py | 0 .../legacy => }/rlhf_online_quant.py | 0 .../{weight_syncing/legacy => }/rlhf_utils.py | 0 .../weight_syncing/rlhf_ipc.py | 186 ------------------ vllm/config/weight_transfer.py | 2 +- vllm/distributed/weight_transfer/__init__.py | 5 - vllm/distributed/weight_transfer/base.py | 4 +- 
.../distributed/weight_transfer/ipc_engine.py | 129 ------------ vllm/engine/arg_utils.py | 6 +- 15 files changed, 9 insertions(+), 328 deletions(-) rename examples/offline_inference/{weight_syncing => new_weight_syncing}/rlhf.py (100%) rename examples/offline_inference/{weight_syncing => new_weight_syncing}/rlhf_async_new_apis.py (100%) rename examples/offline_inference/{weight_syncing => new_weight_syncing}/rlhf_http.py (100%) rename examples/offline_inference/{weight_syncing => new_weight_syncing}/rlhf_utils.py (100%) rename examples/offline_inference/{weight_syncing/legacy => }/rlhf.py (100%) rename examples/offline_inference/{weight_syncing/legacy => }/rlhf_colocate.py (100%) rename examples/offline_inference/{weight_syncing/legacy => }/rlhf_online_quant.py (100%) rename examples/offline_inference/{weight_syncing/legacy => }/rlhf_utils.py (100%) delete mode 100644 examples/offline_inference/weight_syncing/rlhf_ipc.py delete mode 100644 vllm/distributed/weight_transfer/ipc_engine.py diff --git a/docs/training/rlhf.md b/docs/training/rlhf.md index 2d01686b8f5f..0b7e384dc8d6 100644 --- a/docs/training/rlhf.md +++ b/docs/training/rlhf.md @@ -18,8 +18,9 @@ The following open-source RL libraries use vLLM for fast rollouts (sorted alphab See the following basic examples to get started if you don't want to use an existing library: -- [Training and inference processes are located on separate GPUs](../examples/offline_inference/weight_syncing.md#rlhf) -- [Training and inference processes communicate via IPC](../examples/offline_inference/weight_syncing.md#rlhf-ipc) +- [Training and inference processes are located on separate GPUs (inspired by OpenRLHF)](../examples/offline_inference/rlhf.md) +- [Training and inference processes are colocated on the same GPUs using Ray](../examples/offline_inference/rlhf_colocate.md) +- [Utilities for performing RLHF with vLLM](../examples/offline_inference/rlhf_utils.md) See the following notebooks showing how to use vLLM for GRPO: diff 
--git a/examples/offline_inference/weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py similarity index 100% rename from examples/offline_inference/weight_syncing/rlhf.py rename to examples/offline_inference/new_weight_syncing/rlhf.py diff --git a/examples/offline_inference/weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py similarity index 100% rename from examples/offline_inference/weight_syncing/rlhf_async_new_apis.py rename to examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py diff --git a/examples/offline_inference/weight_syncing/rlhf_http.py b/examples/offline_inference/new_weight_syncing/rlhf_http.py similarity index 100% rename from examples/offline_inference/weight_syncing/rlhf_http.py rename to examples/offline_inference/new_weight_syncing/rlhf_http.py diff --git a/examples/offline_inference/weight_syncing/rlhf_utils.py b/examples/offline_inference/new_weight_syncing/rlhf_utils.py similarity index 100% rename from examples/offline_inference/weight_syncing/rlhf_utils.py rename to examples/offline_inference/new_weight_syncing/rlhf_utils.py diff --git a/examples/offline_inference/weight_syncing/legacy/rlhf.py b/examples/offline_inference/rlhf.py similarity index 100% rename from examples/offline_inference/weight_syncing/legacy/rlhf.py rename to examples/offline_inference/rlhf.py diff --git a/examples/offline_inference/weight_syncing/legacy/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py similarity index 100% rename from examples/offline_inference/weight_syncing/legacy/rlhf_colocate.py rename to examples/offline_inference/rlhf_colocate.py diff --git a/examples/offline_inference/weight_syncing/legacy/rlhf_online_quant.py b/examples/offline_inference/rlhf_online_quant.py similarity index 100% rename from examples/offline_inference/weight_syncing/legacy/rlhf_online_quant.py rename to examples/offline_inference/rlhf_online_quant.py diff --git 
a/examples/offline_inference/weight_syncing/legacy/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py similarity index 100% rename from examples/offline_inference/weight_syncing/legacy/rlhf_utils.py rename to examples/offline_inference/rlhf_utils.py diff --git a/examples/offline_inference/weight_syncing/rlhf_ipc.py b/examples/offline_inference/weight_syncing/rlhf_ipc.py deleted file mode 100644 index 30167de343f0..000000000000 --- a/examples/offline_inference/weight_syncing/rlhf_ipc.py +++ /dev/null @@ -1,186 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray, -with new weight syncing APIs - -The script colocates the training and inference workloads onto the same GPU using Ray. - -The example performs the following steps: - -* Request a placement group of 1 GPU. -* Place the inference model on the above GPU using the placement group. -* Place and load the training model on the same GPU using the placement group. -* Generate text from a list of prompts using the inference engine. -* Update the weights of the training model and broadcast the updated weights - to the inference engine by using a Ray collective RPC group. Note that - for demonstration purposes we simply zero out the weights. - -This example assumes a single-node cluster with a single GPUs, -but can be extended to multiple GPUs. 
-""" - -import os -from dataclasses import asdict - -import ray -import torch -from ray.util.placement_group import placement_group -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from transformers import AutoModelForCausalLM - -from vllm import LLM, SamplingParams -from vllm.config import WeightTransferConfig -from vllm.distributed.weight_transfer.base import ( - WeightTransferInitRequest, - WeightUpdateRequest, -) -from vllm.distributed.weight_transfer.ipc_engine import IPCUpdateInfo - - -class MyLLM(LLM): - """Configure the vLLM worker for Ray placement group execution.""" - - def __init__(self, *args, **kwargs): - # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray - # so that vLLM can manage its own device placement within the worker. - os.environ.pop("CUDA_VISIBLE_DEVICES", None) - # Each worker uses 0.4 GPU so that two instances fit on the same GPUs. - os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" - os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0" - # needed for ipc handle serialization - os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" - super().__init__(*args, **kwargs) - - -def get_physical_gpu_id(): - device = torch.cuda.current_device() - props = torch.cuda.get_device_properties(device) - return str(props.uuid) - - -# Load the OPT-125M model onto GPU 0 for the training workload. 
- -MODEL_NAME = "facebook/opt-125m" - - -@ray.remote -class TrainModel: - def __init__(self, llm_handle: ray.ObjectRef): - self.train_model = AutoModelForCausalLM.from_pretrained( - MODEL_NAME, - ) - self.train_model.to("cuda:0") - self.llm_handle = llm_handle - - def init_weight_transfer(self): - self.llm_handle.init_weight_transfer.remote(WeightTransferInitRequest()) - - def broadcast_weights(self, llm_handle: ray.ObjectRef): - self.llm_handle = llm_handle - names, dtypes, shapes, ipc_handles = [], [], [], [] - - for name, p in self.train_model.named_parameters(): - names.append(name) - dtypes.append(str(p.dtype).split(".")[-1]) - shapes.append(p.shape) - - from torch.multiprocessing.reductions import reduce_tensor - - weight = p.detach().contiguous() - ipc_handle = reduce_tensor(weight) - ipc_handle = {get_physical_gpu_id(): ipc_handle} - ipc_handles.append(ipc_handle) - - ray.get( - self.llm_handle.update_weights.remote( - WeightUpdateRequest( - update_info=asdict( - IPCUpdateInfo( - names=names, - dtype_names=dtypes, - shapes=shapes, - ipc_handles=ipc_handles, - ) - ) - ) - ) - ) - - def zero_data(self): - # Simulate a training step by zeroing out all model weights. - # In a real RLHF training loop the weights would be updated using the gradient - # from an RL objective such as PPO on a reward model. 
- for name, p in self.train_model.named_parameters(): - p.data.zero_() - - -ray.init(runtime_env={"excludes": [".git/objects/pack/"]}) - -pg_colocate = placement_group([{"GPU": 1, "CPU": 0}]) -ray.get(pg_colocate.ready()) - - -llm = ray.remote( - num_cpus=0, - num_gpus=0, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg_colocate, - placement_group_capture_child_tasks=True, - ), -)(MyLLM).remote( - model=MODEL_NAME, - enforce_eager=True, - tensor_parallel_size=1, - distributed_executor_backend="ray", - gpu_memory_utilization=0.7, - weight_transfer_config=WeightTransferConfig(backend="ipc"), -) - -train_model = TrainModel.options( - num_gpus=0.1, - num_cpus=0, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg_colocate, placement_group_capture_child_tasks=True - ), -).remote(llm) - - -# Generate text from the prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -sampling_params = SamplingParams(temperature=0) - -outputs = ray.get(llm.generate.remote(prompts, sampling_params)) - -print("-" * 50) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - -ray.get(train_model.init_weight_transfer.remote()) - -train_model.zero_data.remote() - -# Synchronize the updated weights to the inference engine using batched API. -ray.get(train_model.broadcast_weights.remote(llm)) - -# Finalize the weight update (processes weights for quantization/kernel format) -ray.get(llm.finalize_weight_update.remote()) - -# Generate text with the updated model. The output is expected to be nonsense -# because the weights are zero. 
-outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) -print("-" * 50) -for output in outputs_updated: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) diff --git a/vllm/config/weight_transfer.py b/vllm/config/weight_transfer.py index 06cd0d68aed7..7ccac13fbfaa 100644 --- a/vllm/config/weight_transfer.py +++ b/vllm/config/weight_transfer.py @@ -11,5 +11,5 @@ class WeightTransferConfig: """Configuration for weight transfer during RL training.""" - backend: Literal["nccl", "ipc"] = "nccl" + backend: Literal["nccl"] = "nccl" """The backend to use for weight transfer.""" diff --git a/vllm/distributed/weight_transfer/__init__.py b/vllm/distributed/weight_transfer/__init__.py index 4f7ec6f4975b..82b853428327 100644 --- a/vllm/distributed/weight_transfer/__init__.py +++ b/vllm/distributed/weight_transfer/__init__.py @@ -11,16 +11,12 @@ WeightTransferEngine, WeightUpdateRequest, ) -from vllm.distributed.weight_transfer.ipc_engine import ( - IPCWeightTransferEngine, -) from vllm.distributed.weight_transfer.nccl_engine import ( NCCLWeightTransferEngine, ) WEIGHT_TRANSFER_ENGINE_REGISTRY = { "nccl": NCCLWeightTransferEngine, - "ipc": IPCWeightTransferEngine, } @@ -45,6 +41,5 @@ def init_transfer_engine(config: WeightTransferConfig, parallel_config: Parallel "NCCLWeightTransferEngine", "register_weight_transfer_engine", "WEIGHT_TRANSFER_ENGINE_REGISTRY", - "IPCWeightTransferEngine", "WeightUpdateRequest", ] diff --git a/vllm/distributed/weight_transfer/base.py b/vllm/distributed/weight_transfer/base.py index 9251405939db..97ae60e4ab46 100644 --- a/vllm/distributed/weight_transfer/base.py +++ b/vllm/distributed/weight_transfer/base.py @@ -52,7 +52,7 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]): from a trainer to inference workers. 
This abstraction separates weight transfer transport logic from the worker - implementation, allowing different backends (NCCL, CUDA IPC, RDMA) to be + implementation, allowing different backends (NCCL, CUDA IPC[TODO], RDMA[TODO]) to be plugged in. Subclasses should define: @@ -139,7 +139,7 @@ def receive_weights( Args: update_info: Backend-specific update info containing parameter metadata - and any backend-specific data (e.g., IPC handles) + and any backend-specific data load_weights: Callable that loads weights into the model. Called incrementally for each weight to avoid OOM. """ diff --git a/vllm/distributed/weight_transfer/ipc_engine.py b/vllm/distributed/weight_transfer/ipc_engine.py deleted file mode 100644 index 758b7c06df53..000000000000 --- a/vllm/distributed/weight_transfer/ipc_engine.py +++ /dev/null @@ -1,129 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable -from dataclasses import dataclass - -import torch - -from vllm.config.parallel import ParallelConfig -from vllm.config.weight_transfer import WeightTransferConfig -from vllm.distributed.weight_transfer.base import ( - BackendInitInfo, - BackendUpdateInfo, - WeightTransferEngine, -) - - -@dataclass -class IPCInitInfo(BackendInitInfo): - """Initialization info for IPC weight transfer backend. 
No init needed for IPC.""" - - pass - - -@dataclass -class IPCUpdateInfo(BackendUpdateInfo): - """Update info for IPC weight transfer backend.""" - - names: list[str] - dtype_names: list[str] - shapes: list[list[int]] - ipc_handles: list[dict[str, tuple[Callable, tuple]]] - - def __post_init__(self): - """Validate that all lists have the same length.""" - num_params = len(self.names) - if len(self.dtype_names) != num_params: - raise ValueError( - f"`dtype_names` should be of the same size as `names`: " - f"got {len(self.dtype_names)} and {len(self.names)}" - ) - if len(self.shapes) != num_params: - raise ValueError( - f"`shapes` should be of the same size as `names`: " - f"got {len(self.shapes)} and {len(self.names)}" - ) - if len(self.ipc_handles) != num_params: - raise ValueError( - f"`ipc_handles` should be of the same size as `names`: " - f"got {len(self.ipc_handles)} and {len(self.names)}" - ) - - -class IPCWeightTransferEngine(WeightTransferEngine[IPCInitInfo, IPCUpdateInfo]): - """ - Weight transfer engine using CUDA IPC for communication between trainer and workers. - - This implementation uses CUDA IPC to transfer weights from the trainer (rank 0) - to all inference workers in a process group. - """ - - # Define backend-specific dataclass types - init_info_cls = IPCInitInfo - update_info_cls = IPCUpdateInfo - - def __init__( - self, config: WeightTransferConfig, parallel_config: ParallelConfig - ) -> None: - """ - Initialize the weight transfer engine. - - Args: - config: The configuration for the weight transfer engine - parallel_config: The configuration for the parallel setup - """ - super().__init__(config, parallel_config) - - def init_transfer(self, init_info: IPCInitInfo) -> None: - """ - Initialize the weight transfer mechanism. - This is called once at the beginning of training. - No initialization needed for IPC backend. 
- - Args: - init_info: IPC initialization info (empty) - """ - pass - - def receive_weights( - self, - update_info: IPCUpdateInfo, - load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], - ) -> None: - """ - Receive weights from the trainer via CUDA IPC handles. - - Args: - update_info: IPC update info containing parameter names, dtypes, shapes, - and IPC handles. Each IPC handle is a mapping between physical - GPU UUID and the IPC handle tuple (func, args). - - Returns: - List of (name, weight_tensor) tuples ready to be loaded into the model - """ - weights = [] - for name, _dtype_name, _shape, ipc_handle in zip( - update_info.names, - update_info.dtype_names, - update_info.shapes, - update_info.ipc_handles, - ): - device_index = torch.cuda.current_device() - props = torch.cuda.get_device_properties() - physical_gpu_id = str(props.uuid) - - handle = ipc_handle[physical_gpu_id] - - func, args = handle - list_args = list(args) # type: ignore - list_args[6] = device_index - weight = func(*list_args) # type: ignore - weights.append((name, weight)) - - load_weights(weights) - - def shutdown(self) -> None: - """ - Shutdown the weight transfer engine. - """ - pass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8a180cd5906c..ebbb3e781e8a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -580,7 +580,7 @@ class EngineArgs: tokens_only: bool = False weight_transfer_backend: str | None = None - """Backend for weight transfer during RL training. Options: nccl, ipc""" + """Backend for weight transfer during RL training. Options: nccl""" weight_transfer_config: WeightTransferConfig = field( default_factory=WeightTransferConfig @@ -1195,10 +1195,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: vllm_group.add_argument( "--weight-transfer-backend", type=str, - choices=["nccl", "ipc"], + choices=["nccl"], default="nccl", help="Backend for weight transfer during RL training. 
" - "Options: nccl (distributed), ipc (same-node shared memory)" + "Options: nccl (distributed)" "Default: nccl when enabled.", ) From 892b736746d9a7ae7ad1cd47717bbaf607cb1d71 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 26 Jan 2026 12:49:29 -0800 Subject: [PATCH 18/36] edit examples to start with random weights, then weight sync to trained weights Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf.py | 31 ++++++------------- .../new_weight_syncing/rlhf_async_new_apis.py | 2 +- .../new_weight_syncing/rlhf_http.py | 31 ++++++++----------- 3 files changed, 24 insertions(+), 40 deletions(-) diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py index b322ba24ddf0..e9c60751a00e 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray, -with new weight syncing APIs. +with native weight syncing APIs at engine instance. The script separates training and inference workloads onto distinct GPUs so that Ray can manage process placement and inter-process communication. @@ -13,11 +13,10 @@ * Load the training model on GPU 0. * Split the inference model across GPUs 1–2 using vLLM's tensor parallelism - and Ray placement groups. + and Ray placement groups with dummy weights. * Generate text from a list of prompts using the inference engine. * Update the weights of the training model and broadcast the updated weights - to the inference engine by using a Ray collective RPC group. Note that - for demonstration purposes we simply zero out the weights. + to the inference engine by using a Ray collective RPC group. This example assumes a single-node cluster with three GPUs, but Ray supports multi-node clusters. 
vLLM expects the GPUs are only used for vLLM @@ -55,9 +54,7 @@ class MyLLM(LLM): """Configure the vLLM worker for Ray placement group execution.""" def __init__(self, *args, **kwargs): - # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray - # so that vLLM can manage its own device placement within the worker. - os.environ.pop("CUDA_VISIBLE_DEVICES", None) + os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1" super().__init__(*args, **kwargs) @@ -73,11 +70,6 @@ def __init__(self, model_name: str): ) self.model.to(self.device) - def zero_weights(self): - """Zero out all model weights (simulates training step).""" - for name, p in self.model.named_parameters(): - p.data.zero_() - def get_weight_metadata(self): """Return weight names, dtypes, and shapes for weight transfer.""" names = [] @@ -140,7 +132,7 @@ def broadcast_weights(self, packed: bool = True): data_parallel_size=1, distributed_executor_backend="ray", weight_transfer_config=WeightTransferConfig(backend="nccl"), - # enable_expert_parallel=True, + load_format="dummy", ) # Generate text from the prompts. @@ -155,6 +147,8 @@ def broadcast_weights(self, packed: bool = True): outputs = ray.get(llm.generate.remote(prompts, sampling_params)) +# Generate text with the initial model. The output is expected to be nonsense +# because the weights are randomly initialized. print("-" * 50) for output in outputs: prompt = output.prompt @@ -187,11 +181,6 @@ def broadcast_weights(self, packed: bool = True): ) ray.get([train_handle, inference_handle]) -# Simulate a training step by zeroing out all model weights. -# In a real RLHF training loop the weights would be updated using the gradient -# from an RL objective such as PPO on a reward model. -ray.get(train_model.zero_weights.remote()) - # Synchronize the updated weights to the inference engine using batched API. 
# Collect all weight metadata from the training actor names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote()) @@ -215,11 +204,11 @@ def broadcast_weights(self, packed: bool = True): train_handle = train_model.broadcast_weights.remote(packed=True) ray.get([train_handle, inference_handle]) -# Finalize the weight update (processes weights for quantization/kernel format) +# Finalize the weight update ray.get(llm.finalize_weight_update.remote()) -# Generate text with the updated model. The output is expected to be nonsense -# because the weights are zero. +# Generate text with the updated model. The output is expected to be normal +# because the weights are updated. outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) print("-" * 50) for output in outputs_updated: diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index 90e1ef1457bc..b6799fe7ed54 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -103,6 +103,7 @@ class MyLLM: """Simple wrapper over AsyncLLM for supporting async RL.""" def __init__(self, **kwargs): + os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1" self.engine = vllm.AsyncLLMEngine.from_engine_args( vllm.AsyncEngineArgs(**kwargs) ) @@ -162,7 +163,6 @@ async def finalize_weight_update(self) -> None: # be placed on GPUs 1 and 2. os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" ray.init(runtime_env={"excludes": [".git/objects/pack/"]}) -# ray.init() # Launch the training model actor. Ray's resource scheduler will allocate # 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. 
diff --git a/examples/offline_inference/new_weight_syncing/rlhf_http.py b/examples/offline_inference/new_weight_syncing/rlhf_http.py index 2c402ecb8ce6..e775621364b5 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_http.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_http.py @@ -15,7 +15,8 @@ $ vllm serve facebook/opt-125m \ --enforce-eager \ - --weight-transfer-backend nccl + --weight-transfer-backend nccl \ + --load-format dummy Then run this script: @@ -24,12 +25,12 @@ The example performs the following steps: * Load the training model on GPU 0. -* Generate text using the vLLM server via OpenAI-compatible API. +* Generate text using the vLLM server via OpenAI-compatible API. The output + is expected to be nonsense because the server is initialized with dummy weights. * Initialize weight transfer via HTTP endpoint. -* Update the weights of the training model and broadcast the updated weights - to the vLLM server using NCCL. Note that for demonstration purposes we - simply zero out the weights. -* Generate text again to show the effect of the weight update. +* Broadcast the real weights from the training model to the vLLM server + using NCCL. +* Generate text again to show normal output after the weight update. """ from dataclasses import asdict @@ -152,9 +153,10 @@ def main(): "The future of AI is", ] - # Generate text before weight update + # Generate text before weight update. The output is expected to be nonsense + # because the server is initialized with dummy weights. print("-" * 50) - print("Generating text BEFORE weight update:") + print("Generating text BEFORE weight update (expect nonsense):") print("-" * 50) outputs = generate_completions(client, MODEL_NAME, prompts) for prompt, generated_text in zip(prompts, outputs): @@ -187,13 +189,6 @@ def main(): # Wait for init_weight_transfer to complete init_thread.join() - # Simulate a training step by zeroing out all model weights. 
- # In a real RLHF training loop the weights would be updated using the - # gradient from an RL objective such as PPO on a reward model. - print("Simulating training step (zeroing out weights)...") - for name, p in train_model.named_parameters(): - p.data.zero_() - # Collect weight metadata for the update request names = [] dtype_names = [] @@ -226,10 +221,10 @@ def main(): # Finalize the weight update (processes weights for quantization/kernel format) finalize_weight_update(BASE_URL) - # Generate text after weight update. The output is expected to be nonsense - # because the weights are zero. + # Generate text after weight update. The output is expected to be normal + # because the real weights are now loaded. print("-" * 50) - print("Generating text AFTER weight update (expect nonsense):") + print("Generating text AFTER weight update:") print("-" * 50) outputs_updated = generate_completions(client, MODEL_NAME, prompts) for prompt, generated_text in zip(prompts, outputs_updated): From 27a1441d8fcfc5793723f37f6c9c7be14bd8dec9 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 26 Jan 2026 15:07:14 -0800 Subject: [PATCH 19/36] added weight transfer factory Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf.py | 38 ++-- .../new_weight_syncing/rlhf_utils.py | 20 --- .../rlhf_http.py | 51 +++--- tests/distributed/test_weight_transfer.py | 89 ++-------- .../test_weight_transfer_llm.py | 20 +-- vllm/distributed/weight_transfer/__init__.py | 37 +--- vllm/distributed/weight_transfer/factory.py | 164 ++++++++++++++++++ .../weight_transfer/nccl_engine.py | 38 ++-- vllm/entrypoints/llm.py | 20 +-- vllm/entrypoints/serve/rlhf/api_router.py | 6 +- vllm/v1/worker/gpu_worker.py | 7 +- 11 files changed, 273 insertions(+), 217 deletions(-) delete mode 100644 examples/offline_inference/new_weight_syncing/rlhf_utils.py rename examples/{offline_inference/new_weight_syncing => online_serving}/rlhf_http.py (87%) create mode 100644 
vllm/distributed/weight_transfer/factory.py diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py index e9c60751a00e..bb58f16ac96b 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -25,24 +25,16 @@ """ import os -from dataclasses import asdict import ray import torch from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from rlhf_utils import stateless_init_process_group from transformers import AutoModelForCausalLM from vllm import LLM, SamplingParams from vllm.config import WeightTransferConfig -from vllm.distributed.weight_transfer.base import ( - WeightTransferInitRequest, - WeightUpdateRequest, -) from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLInitInfo, - NCCLUpdateInfo, NCCLWeightTransferEngine, ) from vllm.utils.network_utils import get_ip, get_open_port @@ -83,7 +75,7 @@ def get_weight_metadata(self): def init_weight_transfer_group(self, master_address, master_port, world_size): """Initialize the NCCL process group for weight transfer.""" - self.model_update_group = stateless_init_process_group( + self.model_update_group = NCCLWeightTransferEngine.stateless_init_process_group( master_address, master_port, 0, world_size, self.device ) @@ -163,14 +155,12 @@ def broadcast_weights(self, packed: bool = True): world_size = ray.get(llm.get_world_size.remote()) + 1 # +1 for the trainer inference_handle = llm.init_weight_transfer.remote( - WeightTransferInitRequest( - init_info=asdict( - NCCLInitInfo( - master_address=master_address, - master_port=master_port, - rank_offset=1, - world_size=world_size, - ) + dict( + init_info=dict( + master_address=master_address, + master_port=master_port, + rank_offset=1, + world_size=world_size, ) ) ) @@ -188,14 +178,12 @@ def broadcast_weights(self, packed: bool = True): # Issue 
update_weights call with NCCL-specific update info # packed=True enables efficient batched tensor broadcasting inference_handle = llm.update_weights.remote( - WeightUpdateRequest( - update_info=asdict( - NCCLUpdateInfo( - names=names, - dtype_names=dtype_names, - shapes=shapes, - packed=True, - ) + dict( + update_info=dict( + names=names, + dtype_names=dtype_names, + shapes=shapes, + packed=True, ) ) ) diff --git a/examples/offline_inference/new_weight_syncing/rlhf_utils.py b/examples/offline_inference/new_weight_syncing/rlhf_utils.py deleted file mode 100644 index 35761ae3996d..000000000000 --- a/examples/offline_inference/new_weight_syncing/rlhf_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - - -def stateless_init_process_group(master_address, master_port, rank, world_size, device): - """ - vLLM provides `StatelessProcessGroup` to create a process group - without considering the global process group in torch.distributed. - It is recommended to create `StatelessProcessGroup`, and then initialize - the data-plane communication (NCCL) between external (train processes) - and vLLM workers. 
- """ - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - from vllm.distributed.utils import StatelessProcessGroup - - pg = StatelessProcessGroup.create( - host=master_address, port=master_port, rank=rank, world_size=world_size - ) - pynccl = PyNcclCommunicator(pg, device=device) - return pynccl diff --git a/examples/offline_inference/new_weight_syncing/rlhf_http.py b/examples/online_serving/rlhf_http.py similarity index 87% rename from examples/offline_inference/new_weight_syncing/rlhf_http.py rename to examples/online_serving/rlhf_http.py index e775621364b5..7e1f237865e3 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_http.py +++ b/examples/online_serving/rlhf_http.py @@ -33,17 +33,12 @@ * Generate text again to show normal output after the weight update. """ -from dataclasses import asdict - import requests import torch from openai import OpenAI -from rlhf_utils import stateless_init_process_group from transformers import AutoModelForCausalLM from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLInitInfo, - NCCLUpdateInfo, NCCLWeightTransferEngine, ) from vllm.utils.network_utils import get_ip, get_open_port @@ -76,13 +71,11 @@ def init_weight_transfer( """Initialize weight transfer via HTTP endpoint.""" url = f"{base_url}/init_weight_transfer" payload = { - "init_info": asdict( - NCCLInitInfo( - master_address=master_address, - master_port=master_port, - rank_offset=rank_offset, - world_size=world_size, - ) + "init_info": dict( + master_address=master_address, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, ) } response = requests.post(url, json=payload, timeout=60) @@ -99,13 +92,11 @@ def update_weights( """Update weights via HTTP endpoint.""" url = f"{base_url}/update_weights" payload = { - "update_info": asdict( - NCCLUpdateInfo( - names=names, - dtype_names=dtype_names, - shapes=shapes, - packed=packed, - ) + "update_info": dict( + names=names, + dtype_names=dtype_names, + 
shapes=shapes, + packed=packed, ) } response = requests.post(url, json=payload, timeout=300) @@ -119,6 +110,20 @@ def finalize_weight_update(base_url: str) -> None: response.raise_for_status() +def pause_generation(base_url: str) -> None: + """Pause generation via HTTP endpoint.""" + url = f"{base_url}/pause" + response = requests.post(url, timeout=60) + response.raise_for_status() + + +def resume_generation(base_url: str) -> None: + """Resume generation via HTTP endpoint.""" + url = f"{base_url}/resume" + response = requests.post(url, timeout=60) + response.raise_for_status() + + def get_world_size(base_url: str) -> int: """Get world size from the vLLM server.""" url = f"{base_url}/get_world_size" @@ -182,13 +187,16 @@ def main(): init_thread.start() # Initialize NCCL process group on trainer side - model_update_group = stateless_init_process_group( + model_update_group = NCCLWeightTransferEngine.stateless_init_process_group( master_address, master_port, 0, world_size, torch.device(device) ) # Wait for init_weight_transfer to complete init_thread.join() + # Pause generation before weight sync + pause_generation(BASE_URL) + # Collect weight metadata for the update request names = [] dtype_names = [] @@ -221,6 +229,9 @@ def main(): # Finalize the weight update (processes weights for quantization/kernel format) finalize_weight_update(BASE_URL) + # Resume generation after weight sync + resume_generation(BASE_URL) + # Generate text after weight update. The output is expected to be normal # because the real weights are now loaded. 
print("-" * 50) diff --git a/tests/distributed/test_weight_transfer.py b/tests/distributed/test_weight_transfer.py index 8fb33723fd76..b98882cf4934 100644 --- a/tests/distributed/test_weight_transfer.py +++ b/tests/distributed/test_weight_transfer.py @@ -15,20 +15,11 @@ from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig -from vllm.distributed.weight_transfer import ( - WEIGHT_TRANSFER_ENGINE_REGISTRY, - init_transfer_engine, - register_weight_transfer_engine, -) +from vllm.distributed.weight_transfer import WeightTransferEngineFactory from vllm.distributed.weight_transfer.base import ( BackendInitInfo, WeightTransferEngine, ) -from vllm.distributed.weight_transfer.ipc_engine import ( - IPCInitInfo, - IPCUpdateInfo, - IPCWeightTransferEngine, -) from vllm.distributed.weight_transfer.nccl_engine import ( NCCLInitInfo, NCCLUpdateInfo, @@ -95,33 +86,6 @@ def test_empty_lists_valid(self): assert len(info.names) == 0 -# --- Unit Tests: IPCUpdateInfo Validation --- - - -class TestIPCUpdateInfoValidation: - """Test IPCUpdateInfo dataclass validation.""" - - def test_valid_update_info(self): - """Test creating valid IPCUpdateInfo.""" - info = IPCUpdateInfo( - names=["layer.weight"], - dtype_names=["float32"], - shapes=[[10, 10]], - ipc_handles=[{"gpu-uuid": (lambda: None, ())}], - ) - assert info.names == ["layer.weight"] - - def test_mismatched_ipc_handles_raises(self): - """Test that mismatched ipc_handles length raises ValueError.""" - with pytest.raises(ValueError, match="ipc_handles"): - IPCUpdateInfo( - names=["layer.weight", "layer.bias"], - dtype_names=["float32", "float32"], - shapes=[[10, 10], [10]], - ipc_handles=[{}], # Only one handle - ) - - # --- Unit Tests: Engine Parsing --- @@ -183,54 +147,25 @@ def test_parse_update_info_valid(self): assert update_info.shapes == [[100, 100], [50]] -class TestIPCEngineParsing: - """Test IPCWeightTransferEngine parsing methods.""" - - def 
test_parse_init_info_empty(self): - """Test parsing empty init info (IPC doesn't need init params).""" - config = WeightTransferConfig(backend="ipc") - parallel_config = create_mock_parallel_config() - engine = IPCWeightTransferEngine(config, parallel_config) - - init_info = engine.parse_init_info({}) - assert isinstance(init_info, IPCInitInfo) - - def test_init_transfer_is_noop(self): - """Test that IPC init_transfer is a no-op.""" - config = WeightTransferConfig(backend="ipc") - parallel_config = create_mock_parallel_config() - engine = IPCWeightTransferEngine(config, parallel_config) - - # Should not raise - engine.init_transfer(IPCInitInfo()) - - # --- Unit Tests: Engine Registry --- class TestEngineRegistry: """Test weight transfer engine registry.""" - def test_init_transfer_engine_nccl(self): - """Test init_transfer_engine creates NCCL engine.""" + def test_create_engine_nccl(self): + """Test factory creates NCCL engine.""" config = WeightTransferConfig(backend="nccl") parallel_config = create_mock_parallel_config() - engine = init_transfer_engine(config, parallel_config) + engine = WeightTransferEngineFactory.create_engine(config, parallel_config) assert isinstance(engine, NCCLWeightTransferEngine) - def test_init_transfer_engine_ipc(self): - """Test init_transfer_engine creates IPC engine.""" - config = WeightTransferConfig(backend="ipc") - parallel_config = create_mock_parallel_config() - engine = init_transfer_engine(config, parallel_config) - assert isinstance(engine, IPCWeightTransferEngine) - - def test_init_transfer_engine_invalid_backend(self): - """Test init_transfer_engine raises for invalid backend.""" + def test_create_engine_invalid_backend(self): + """Test factory raises for invalid backend.""" config = WeightTransferConfig(backend="invalid") parallel_config = create_mock_parallel_config() with pytest.raises(ValueError, match="Invalid weight transfer backend"): - init_transfer_engine(config, parallel_config) + 
WeightTransferEngineFactory.create_engine(config, parallel_config) def test_register_custom_engine(self): """Test registering a custom engine.""" @@ -253,16 +188,18 @@ def shutdown(self): pass # Register custom engine - register_weight_transfer_engine("custom_test", CustomEngine) - assert "custom_test" in WEIGHT_TRANSFER_ENGINE_REGISTRY + WeightTransferEngineFactory.register_engine("custom_test", CustomEngine) + assert WeightTransferEngineFactory.is_registered("custom_test") # Clean up - del WEIGHT_TRANSFER_ENGINE_REGISTRY["custom_test"] + WeightTransferEngineFactory.unregister_engine("custom_test") def test_register_duplicate_raises(self): """Test registering duplicate engine name raises.""" with pytest.raises(ValueError, match="already registered"): - register_weight_transfer_engine("nccl", NCCLWeightTransferEngine) + WeightTransferEngineFactory.register_engine( + "nccl", NCCLWeightTransferEngine + ) # --- Test receive_weights without init raises --- diff --git a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py index 22505bbdea51..1c0ac3bbe258 100644 --- a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py +++ b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py @@ -91,8 +91,8 @@ def shutdown(self) -> None: MockWeightTransferEngine.shutdown_called = True -def mock_init_transfer_engine(config, parallel_config): - """Factory function that returns our mock engine.""" +def mock_create_engine(config, parallel_config): + """Mock factory function that returns our mock engine.""" return MockWeightTransferEngine(config, parallel_config) @@ -127,8 +127,8 @@ def test_init_weight_transfer_calls_engine(): os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" with patch( - "vllm.distributed.weight_transfer.init_transfer_engine", - mock_init_transfer_engine, + "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine", + mock_create_engine, ): llm = LLM( 
model=MODEL_NAME, @@ -174,8 +174,8 @@ def test_update_weights_calls_engine(): os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" with patch( - "vllm.distributed.weight_transfer.init_transfer_engine", - mock_init_transfer_engine, + "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine", + mock_create_engine, ): llm = LLM( model=MODEL_NAME, @@ -228,8 +228,8 @@ def test_finalize_weight_update_runs(): pytest.skip("Need at least 1 GPU for this test") with patch( - "vllm.distributed.weight_transfer.init_transfer_engine", - mock_init_transfer_engine, + "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine", + mock_create_engine, ): llm = LLM( model=MODEL_NAME, @@ -254,8 +254,8 @@ def test_full_weight_transfer_flow(): os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" with patch( - "vllm.distributed.weight_transfer.init_transfer_engine", - mock_init_transfer_engine, + "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine", + mock_create_engine, ): llm = LLM( model=MODEL_NAME, diff --git a/vllm/distributed/weight_transfer/__init__.py b/vllm/distributed/weight_transfer/__init__.py index 82b853428327..c96ad0e3bb4f 100644 --- a/vllm/distributed/weight_transfer/__init__.py +++ b/vllm/distributed/weight_transfer/__init__.py @@ -5,41 +5,8 @@ to inference workers. 
""" -from vllm.config.parallel import ParallelConfig -from vllm.config.weight_transfer import WeightTransferConfig -from vllm.distributed.weight_transfer.base import ( - WeightTransferEngine, - WeightUpdateRequest, -) -from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLWeightTransferEngine, -) - -WEIGHT_TRANSFER_ENGINE_REGISTRY = { - "nccl": NCCLWeightTransferEngine, -} - - -def register_weight_transfer_engine( - name: str, engine: type[WeightTransferEngine] -) -> None: - if name in WEIGHT_TRANSFER_ENGINE_REGISTRY: - raise ValueError(f"Weight transfer engine {name} already registered") - WEIGHT_TRANSFER_ENGINE_REGISTRY[name] = engine - - -def init_transfer_engine(config: WeightTransferConfig, parallel_config: ParallelConfig): - if config.backend not in WEIGHT_TRANSFER_ENGINE_REGISTRY: - raise ValueError(f"Invalid weight transfer backend: {config.backend}") - - engine_cls = WEIGHT_TRANSFER_ENGINE_REGISTRY[config.backend] - return engine_cls(config, parallel_config) - +from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory __all__ = [ - "WeightTransferEngine", - "NCCLWeightTransferEngine", - "register_weight_transfer_engine", - "WEIGHT_TRANSFER_ENGINE_REGISTRY", - "WeightUpdateRequest", + "WeightTransferEngineFactory", ] diff --git a/vllm/distributed/weight_transfer/factory.py b/vllm/distributed/weight_transfer/factory.py new file mode 100644 index 000000000000..5f24b0062141 --- /dev/null +++ b/vllm/distributed/weight_transfer/factory.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Factory for weight transfer engines with lazy loading.""" + +import importlib +from collections.abc import Callable +from typing import TYPE_CHECKING + +from vllm.distributed.weight_transfer.base import WeightTransferEngine +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.config.parallel import ParallelConfig + from vllm.config.weight_transfer 
import WeightTransferConfig + +logger = init_logger(__name__) + + +class WeightTransferEngineFactory: + """Factory for creating weight transfer engines with lazy loading. + + This factory implements a registry pattern that supports: + - Lazy loading: Engine modules are only imported when actually needed + - Extensibility: Custom engines can be registered at runtime + - Centralized registration: All built-in engines registered in one place + """ + + _registry: dict[str, Callable[[], type[WeightTransferEngine]]] = {} + + @classmethod + def register_engine( + cls, + name: str, + module_path_or_cls: str | type[WeightTransferEngine], + class_name: str | None = None, + ) -> None: + """Register an engine with lazy-loading or direct class reference. + + Supports two calling conventions: + 1. Lazy loading: register_engine(name, module_path, class_name) + 2. Direct class: register_engine(name, engine_cls) + + Args: + name: The name to register the engine under (e.g., "nccl") + module_path_or_cls: Either a module path string for lazy loading, + or the engine class directly + class_name: Name of the engine class (required if module_path is string) + + Raises: + ValueError: If an engine with the same name is already registered + """ + if name in cls._registry: + raise ValueError(f"Weight transfer engine '{name}' is already registered.") + + if isinstance(module_path_or_cls, str): + # Lazy loading path + module_path = module_path_or_cls + if class_name is None: + raise ValueError( + "class_name is required when registering with module path" + ) + + def loader() -> type[WeightTransferEngine]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + else: + # Direct class registration + engine_cls = module_path_or_cls + cls._registry[name] = lambda: engine_cls + + @classmethod + def create_engine( + cls, + config: "WeightTransferConfig", + parallel_config: "ParallelConfig", + ) -> WeightTransferEngine: + """Create a 
weight transfer engine instance. + + Args: + config: Weight transfer configuration containing the backend name + parallel_config: Parallel configuration for the engine + + Returns: + An initialized weight transfer engine instance + + Raises: + ValueError: If the backend is not registered + """ + backend = config.backend + engine_cls = cls.get_engine_class(backend) + + logger.info( + "Creating weight transfer engine: %s", + engine_cls.__name__, + ) + + return engine_cls(config, parallel_config) + + @classmethod + def get_engine_class(cls, backend: str) -> type[WeightTransferEngine]: + """Get a registered engine class by name. + + Args: + backend: Name of the registered engine + + Returns: + The engine class + + Raises: + ValueError: If the engine is not registered + """ + if backend not in cls._registry: + available = list(cls._registry.keys()) + raise ValueError( + f"Invalid weight transfer backend: {backend}. " + f"Available engines: {available}" + ) + return cls._registry[backend]() + + @classmethod + def list_engines(cls) -> list[str]: + """List all registered engine names. + + Returns: + List of registered engine names + """ + return list(cls._registry.keys()) + + @classmethod + def unregister_engine(cls, name: str) -> None: + """Unregister an engine by name. + + Args: + name: Name of the engine to unregister + + Raises: + KeyError: If the engine is not registered + """ + del cls._registry[name] + + @classmethod + def is_registered(cls, name: str) -> bool: + """Check if an engine is registered. + + Args: + name: Name of the engine to check + + Returns: + True if the engine is registered + """ + return name in cls._registry + + +# Register built-in weight transfer engines here. +# Registration should be centralized to ensure lazy loading - +# engine modules are only imported when actually used. 
+ +WeightTransferEngineFactory.register_engine( + "nccl", + "vllm.distributed.weight_transfer.nccl_engine", + "NCCLWeightTransferEngine", +) diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index 09f7f01b34f9..44f5a49e04ff 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -90,8 +90,6 @@ def init_transfer(self, init_info: NCCLInitInfo) -> None: init_info: NCCL initialization info containing master address, port, rank offset, and world size """ - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - from vllm.distributed.utils import StatelessProcessGroup # Calculate the global rank in the trainer-worker process group # Must account for data parallel to get unique ranks across all workers @@ -103,16 +101,12 @@ def init_transfer(self, init_info: NCCLInitInfo) -> None: worker_rank = dp_rank * world_size_per_dp + tp_rank rank = worker_rank + init_info.rank_offset # Create stateless process group - pg = StatelessProcessGroup.create( - host=init_info.master_address, - port=init_info.master_port, - rank=rank, - world_size=init_info.world_size, - ) - - # Initialize NCCL communicator - self.model_update_group = PyNcclCommunicator( - pg, device=torch.cuda.current_device() + self.model_update_group = NCCLWeightTransferEngine.stateless_init_process_group( + init_info.master_address, + init_info.master_port, + rank, + init_info.world_size, + torch.cuda.current_device(), ) def receive_weights( @@ -227,3 +221,23 @@ def trainer_broadcast_weights( for item in iterator: tensor = post_iter_func(item) group.broadcast(tensor, src=src, stream=torch.cuda.current_stream()) + + @staticmethod + def stateless_init_process_group( + master_address, master_port, rank, world_size, device + ): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. 
+ It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + """ + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + pg = StatelessProcessGroup.create( + host=master_address, port=master_port, rank=rank, world_size=world_size + ) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ddf31f1266e4..486442f8135e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1806,35 +1806,31 @@ def _run_engine( # its previous requests. return sorted(outputs, key=lambda x: int(x.request_id)) - def init_weight_transfer(self, request: WeightTransferInitRequest) -> None: + def init_weight_transfer(self, request: WeightTransferInitRequest | dict) -> None: """ Initialize weight transfer for RL training. Args: request: Weight transfer initialization request with backend-specific info """ - - if isinstance(request, WeightTransferInitRequest): - init_info_dict = request.init_info - else: - raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}") + init_info_dict = ( + request["init_info"] if isinstance(request, dict) else request.init_info + ) self.llm_engine.collective_rpc( "init_weight_transfer", kwargs={"init_info": init_info_dict} ) - def update_weights(self, request: WeightUpdateRequest) -> None: + def update_weights(self, request: WeightUpdateRequest | dict) -> None: """ Update the weights of the model. 
Args: request: Weight update request with backend-specific update info """ - - if isinstance(request, WeightUpdateRequest): - update_info_dict = request.update_info - else: - raise TypeError(f"Expected WeightUpdateRequest, got {type(request)}") + update_info_dict = ( + request["update_info"] if isinstance(request, dict) else request.update_info + ) self.llm_engine.collective_rpc( "update_weights", kwargs={"update_info": update_info_dict} diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py index 04b17e9c3b34..6b70f8e72f59 100644 --- a/vllm/entrypoints/serve/rlhf/api_router.py +++ b/vllm/entrypoints/serve/rlhf/api_router.py @@ -7,8 +7,10 @@ from fastapi import APIRouter, FastAPI, HTTPException, Query, Request from fastapi.responses import JSONResponse -from vllm.distributed.weight_transfer import WeightUpdateRequest -from vllm.distributed.weight_transfer.base import WeightTransferInitRequest +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightUpdateRequest, +) from vllm.engine.protocol import EngineClient from vllm.logger import init_logger diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 16b0b1f81768..77799e1a92b6 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -32,9 +32,7 @@ get_pp_group, get_tp_group, ) -from vllm.distributed.weight_transfer import ( - init_transfer_engine, -) +from vllm.distributed.weight_transfer import WeightTransferEngineFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.models.interfaces import is_mixture_of_experts @@ -92,8 +90,7 @@ def __init__( self._sleep_saved_buffers: dict[str, torch.Tensor] = {} # Weight transfer engine (initialized on-demand) - # check if class is in the map - self.weight_transfer_engine = init_transfer_engine( + self.weight_transfer_engine = WeightTransferEngineFactory.create_engine( 
self.vllm_config.weight_transfer_config, self.vllm_config.parallel_config ) From bbc13e947e42dd39d5897119bcae3eabfe1eea1d Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 26 Jan 2026 17:37:20 -0800 Subject: [PATCH 20/36] change config Signed-off-by: ahao-anyscale --- examples/online_serving/rlhf_http.py | 2 +- vllm/config/vllm.py | 4 +-- vllm/engine/arg_utils.py | 38 +++++++++-------------- vllm/entrypoints/llm.py | 16 +++++++--- vllm/entrypoints/serve/rlhf/api_router.py | 21 ++++++++++--- vllm/v1/worker/gpu_worker.py | 19 ++++++++++-- 6 files changed, 61 insertions(+), 39 deletions(-) diff --git a/examples/online_serving/rlhf_http.py b/examples/online_serving/rlhf_http.py index 7e1f237865e3..77908d89fe72 100644 --- a/examples/online_serving/rlhf_http.py +++ b/examples/online_serving/rlhf_http.py @@ -15,7 +15,7 @@ $ vllm serve facebook/opt-125m \ --enforce-eager \ - --weight-transfer-backend nccl \ + --weight-transfer-config '{"backend": "nccl"}' \ --load-format dummy Then run this script: diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index c4d67f8275af..037f21950d17 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -243,9 +243,7 @@ class VllmConfig: performance. -02 is used by defult. 
See OptimizationLevel for full description.""" - weight_transfer_config: WeightTransferConfig = Field( - default_factory=WeightTransferConfig - ) + weight_transfer_config: WeightTransferConfig | None = None """The configurations for weight transfer during RL training.""" def compute_hash(self) -> str: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ebbb3e781e8a..a142f3171864 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,7 +8,7 @@ import json import sys from collections.abc import Callable -from dataclasses import MISSING, dataclass, field, fields, is_dataclass +from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations from types import UnionType from typing import ( @@ -236,17 +236,17 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: # Save time only getting attr docs if we're generating help text cls_docs = get_attr_docs(cls) if NEEDS_HELP else {} kwargs = {} - for fld in fields(cls): + for field in fields(cls): # Get the set of possible types for the field - type_hints: set[TypeHint] = get_type_hints(fld.type) + type_hints: set[TypeHint] = get_type_hints(field.type) # If the field is a dataclass, we can use the model_validate_json generator = (th for th in type_hints if is_dataclass(th)) dataclass_cls = next(generator, None) # Get the default value of the field - if fld.default is not MISSING: - default = fld.default + if field.default is not MISSING: + default = field.default # Handle pydantic.Field defaults if isinstance(default, FieldInfo): if default.default_factory is None: @@ -256,11 +256,11 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: # These could emit logs on init, which would be confusing. 
with suppress_logging(): default = default.default_factory() - elif fld.default_factory is not MISSING: - default = fld.default_factory() + elif field.default_factory is not MISSING: + default = field.default_factory() # Get the help text for the field - name = fld.name + name = field.name help = cls_docs.get(name, "").strip() # Escape % for argparse help = help.replace("%", "%%") @@ -579,12 +579,10 @@ class EngineArgs: kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend tokens_only: bool = False - weight_transfer_backend: str | None = None - """Backend for weight transfer during RL training. Options: nccl""" - - weight_transfer_config: WeightTransferConfig = field( - default_factory=WeightTransferConfig - ) + weight_transfer_config: WeightTransferConfig | None = None + """Configuration for weight transfer during RL training. + Accepts a JSON string or dict with backend-specific options. + Example: '{"backend": "nccl"}'""" def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -596,9 +594,9 @@ def __post_init__(self): self.attention_config = AttentionConfig(**self.attention_config) if isinstance(self.eplb_config, dict): self.eplb_config = EPLBConfig(**self.eplb_config) - if self.weight_transfer_backend is not None: + if isinstance(self.weight_transfer_config, dict): self.weight_transfer_config = WeightTransferConfig( - backend=self.weight_transfer_backend + **self.weight_transfer_config ) # Setup plugins from vllm.plugins import load_general_plugins @@ -1193,13 +1191,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--optimization-level", **vllm_kwargs["optimization_level"] ) vllm_group.add_argument( - "--weight-transfer-backend", - type=str, - choices=["nccl"], - default="nccl", - help="Backend for weight transfer during RL training. 
" - "Options: nccl (distributed)" - "Default: nccl when enabled.", + "--weight-transfer-config", **vllm_kwargs["weight_transfer_config"] ) # Other arguments diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 486442f8135e..2b651a256486 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -361,14 +361,22 @@ def _make_config(value: Any, cls: type[_R]) -> _R: def get_tokenizer(self) -> TokenizerLike: return self.llm_engine.get_tokenizer() - def get_world_size(self) -> int: + def get_world_size(self, include_dp: bool = True) -> int: """Get the world size from the parallel config. + Args: + include_dp: If True (default), returns the world size including + data parallelism (TP * PP * DP). If False, returns the world + size without data parallelism (TP * PP). + Returns: - The world size including data parallelism - (tensor_parallel_size * pipeline_parallel_size * data_parallel_size). + The world size (tensor_parallel_size * pipeline_parallel_size), + optionally multiplied by data_parallel_size if include_dp is True. 
""" - return self.llm_engine.vllm_config.parallel_config.world_size_across_dp + parallel_config = self.llm_engine.vllm_config.parallel_config + if include_dp: + return parallel_config.world_size_across_dp + return parallel_config.world_size def reset_mm_cache(self) -> None: self.input_processor.clear_mm_cache() diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py index 6b70f8e72f59..2cbbac48ed93 100644 --- a/vllm/entrypoints/serve/rlhf/api_router.py +++ b/vllm/entrypoints/serve/rlhf/api_router.py @@ -145,11 +145,22 @@ async def finalize_weight_update(raw_request: Request): @router.get("/get_world_size") -async def get_world_size(raw_request: Request): - """Get the world size from the parallel config (TP * PP * DP).""" - world_size = engine_client( - raw_request - ).vllm_config.parallel_config.world_size_across_dp +async def get_world_size( + raw_request: Request, + include_dp: bool = Query(True), +): + """Get the world size from the parallel config. + + Args: + include_dp: If True (default), returns the world size including + data parallelism (TP * PP * DP). If False, returns the world + size without data parallelism (TP * PP). 
+ """ + parallel_config = engine_client(raw_request).vllm_config.parallel_config + if include_dp: + world_size = parallel_config.world_size_across_dp + else: + world_size = parallel_config.world_size return JSONResponse(content={"world_size": world_size}) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 77799e1a92b6..c3a3ceffe97a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -90,8 +90,13 @@ def __init__( self._sleep_saved_buffers: dict[str, torch.Tensor] = {} # Weight transfer engine (initialized on-demand) - self.weight_transfer_engine = WeightTransferEngineFactory.create_engine( - self.vllm_config.weight_transfer_config, self.vllm_config.parallel_config + self.weight_transfer_engine = ( + WeightTransferEngineFactory.create_engine( + self.vllm_config.weight_transfer_config, + self.vllm_config.parallel_config, + ) + if self.vllm_config.weight_transfer_config is not None + else None ) # Torch/CUDA profiler. Enabled and configured through profiler_config. @@ -934,6 +939,11 @@ def init_weight_transfer(self, init_info: dict) -> None: Args: init_info: Dictionary containing backend-specific initialization info """ + if self.weight_transfer_engine is None: + raise RuntimeError( + "Weight transfer not configured. " + "Please set weight_transfer_config to enable weight transfer." + ) # Parse dict into backend-specific typed dataclass typed_init_info = self.weight_transfer_engine.parse_init_info(init_info) self.weight_transfer_engine.init_transfer(typed_init_info) @@ -946,7 +956,10 @@ def update_weights(self, update_info: dict) -> None: update_info: Dictionary containing backend-specific update info """ if self.weight_transfer_engine is None: - raise RuntimeError("Weight transfer not initialized.") + raise RuntimeError( + "Weight transfer not configured. " + "Please set weight_transfer_config to enable weight transfer." 
+ ) # Parse dict into backend-specific typed dataclass typed_update_info = self.weight_transfer_engine.parse_update_info(update_info) From d44f46ff0565e9e1c7777ca729e88812c4c4738d Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 26 Jan 2026 17:41:29 -0800 Subject: [PATCH 21/36] x Signed-off-by: ahao-anyscale --- vllm/entrypoints/serve/rlhf/api_router.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py index 2cbbac48ed93..c44014d8874a 100644 --- a/vllm/entrypoints/serve/rlhf/api_router.py +++ b/vllm/entrypoints/serve/rlhf/api_router.py @@ -7,6 +7,7 @@ from fastapi import APIRouter, FastAPI, HTTPException, Query, Request from fastapi.responses import JSONResponse +import vllm.envs as envs from vllm.distributed.weight_transfer.base import ( WeightTransferInitRequest, WeightUpdateRequest, @@ -165,4 +166,6 @@ async def get_world_size( def attach_router(app: FastAPI): + if not envs.VLLM_SERVER_DEV_MODE: + return app.include_router(router) From 844a84e93c5495cc8fd53eee9c8ea36cc880fa58 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 26 Jan 2026 17:46:33 -0800 Subject: [PATCH 22/36] x Signed-off-by: ahao-anyscale --- vllm/distributed/weight_transfer/nccl_engine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index 44f5a49e04ff..c317285544f1 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -178,6 +178,7 @@ def trainer_broadcast_weights( post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None, packed: bool = False, + stream: torch.cuda.Stream | None = None, ) -> None: """Broadcast weights from trainer to vLLM workers. 
@@ -220,7 +221,9 @@ def trainer_broadcast_weights( # Use simple one-by-one broadcasting for item in iterator: tensor = post_iter_func(item) - group.broadcast(tensor, src=src, stream=torch.cuda.current_stream()) + group.broadcast( + tensor, src=src, stream=stream or torch.cuda.current_stream() + ) @staticmethod def stateless_init_process_group( From cb96ccd7ab42d8a8f13554068b3beab5dfe9fea5 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Tue, 27 Jan 2026 13:51:03 -0800 Subject: [PATCH 23/36] x Signed-off-by: ahao-anyscale --- .buildkite/test-amd.yaml | 5 +- .buildkite/test-pipeline.yaml | 5 +- .buildkite/test_areas/distributed.yaml | 5 +- .../new_weight_syncing/rlhf.py | 33 ++-- .../new_weight_syncing/rlhf_async_new_apis.py | 172 +++++++++--------- examples/online_serving/rlhf_http.py | 13 +- tests/distributed/test_packed_tensor.py | 56 ++---- tests/distributed/test_weight_transfer.py | 32 ---- vllm/distributed/weight_transfer/factory.py | 60 +----- .../weight_transfer/nccl_engine.py | 79 +++++++- .../weight_transfer/packed_tensor.py | 60 +++--- vllm/envs.py | 6 +- 12 files changed, 243 insertions(+), 283 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index d1d48596e78f..05d4c5f468c8 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -227,7 +227,7 @@ steps: - tests/distributed/test_pynccl - tests/distributed/test_events - tests/compile/fullgraph/test_basic_correctness.py - - examples/offline_inference/weight_syncing/ + - examples/offline_inference/new_weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -263,9 +263,8 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - - pushd ../examples/offline_inference/weight_syncing + - pushd ../examples/offline_inference/new_weight_syncing - 
VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - popd diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 057bc2eb3ab8..5d1bc65a3525 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -200,7 +200,7 @@ steps: - tests/distributed/test_pynccl - tests/distributed/test_events - tests/compile/fullgraph/test_basic_correctness.py - - examples/offline_inference/weight_syncing/ + - examples/offline_inference/new_weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -235,9 +235,8 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - - pushd ../examples/offline_inference/weight_syncing + - pushd ../examples/offline_inference/new_weight_syncing - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - popd diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index 330142452725..5e047371f6d8 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -61,7 +61,7 @@ steps: - tests/distributed/test_pynccl - tests/distributed/test_events - tests/compile/fullgraph/test_basic_correctness.py - - examples/offline_inference/weight_syncing/ + - examples/offline_inference/new_weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -96,9 +96,8 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple 
distributed example tests - - cd ../examples/offline_inference/weight_syncing + - cd ../examples/offline_inference/new_weight_syncing - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - label: Distributed Tests (8 GPUs)(H100) diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py index bb58f16ac96b..84548fbd9965 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -1,22 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray, +Demonstrates reinforcement learning using vLLM and Ray, with native weight syncing APIs at engine instance. The script separates training and inference workloads onto distinct GPUs so that Ray can manage process placement and inter-process communication. -A Hugging Face Transformer model occupies GPU 0 for training, whereas a -tensor-parallel vLLM inference engine occupies GPU 1–2. +A Hugging Face Transformer model occupies one GPU for training, whereas a +2x tensor-parallel vLLM inference engine occupies two GPUs. The example performs the following steps: - -* Load the training model on GPU 0. -* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism - and Ray placement groups with dummy weights. -* Generate text from a list of prompts using the inference engine. +* Load the training model on one gpu (scheduled via ray) +* Initialize the inference model with dummy weights across + two gpus using vLLM's tensor parallelism and Ray placement groups. +* Generate gibberish from a list of prompts using the randomly initialized + inference engine. 
* Update the weights of the training model and broadcast the updated weights to the inference engine by using a Ray collective RPC group. +* Generating from the list of prompts after weight sync should result + in sensible outputs. This example assumes a single-node cluster with three GPUs, but Ray supports multi-node clusters. vLLM expects the GPUs are only used for vLLM @@ -75,13 +77,18 @@ def get_weight_metadata(self): def init_weight_transfer_group(self, master_address, master_port, world_size): """Initialize the NCCL process group for weight transfer.""" - self.model_update_group = NCCLWeightTransferEngine.stateless_init_process_group( - master_address, master_port, 0, world_size, self.device + self.model_update_group = NCCLWeightTransferEngine.trainer_init( + dict( + master_address=master_address, + master_port=master_port, + world_size=world_size, + ), + device=self.device, ) def broadcast_weights(self, packed: bool = True): """Broadcast weights to the inference engine.""" - NCCLWeightTransferEngine.trainer_broadcast_weights( + NCCLWeightTransferEngine.trainer_send_weights( iterator=self.model.named_parameters(), group=self.model_update_group, packed=packed, @@ -90,9 +97,7 @@ def broadcast_weights(self, packed: bool = True): # Initialize Ray and set the visible devices. The vLLM engine will # be placed on GPUs 1 and 2. -os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" -ray.init(runtime_env={"excludes": [".git/objects/pack/"]}) -# ray.init() +ray.init() # Create a placement group that reserves GPU 1–2 for the vLLM inference engine. 
# Learn more about Ray placement groups: diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index b6799fe7ed54..7668fcc38db2 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -1,24 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Demonstrates asynchronous reinforcement learning from human feedback (RLHF) -using vLLM and Ray, with the new weight syncing APIs +Demonstrates async reinforcement learning using vLLM and Ray, +with native weight syncing APIs at engine instance. The script separates training and inference workloads onto distinct GPUs so that Ray can manage process placement and inter-process communication. -A Hugging Face Transformer model occupies GPU 0 for training, whereas a -tensor-parallel vLLM inference engine occupies GPU 1–2. +A Hugging Face Transformer model occupies one GPU for training, whereas a +2x tensor-parallel vLLM inference engine occupies two GPUs. The example performs the following steps: - -* Load the training model on GPU 0. -* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism - and Ray placement groups. -* Start generation from a list of prompts using the inference engine. +* Load the training model on one gpu (scheduled via ray) +* Initialize the inference model with dummy weights across + two gpus using vLLM's tensor parallelism and Ray placement groups. +* Generate gibberish from a list of prompts using the randomly initialized + inference engine. * Pause generation once generation completes for one sequence * Update the weights of the training model and broadcast the updated weights - to the inference engine by using a Ray collective RPC group. Note that - for demonstration purposes we simply zero out the weights. 
+ to the inference engine by using a Ray collective RPC group. * Resume generation and print out the results This example assumes a single-node cluster with three GPUs, but Ray @@ -27,7 +26,6 @@ causes unexpected behavior. """ -import asyncio import os import uuid from dataclasses import asdict @@ -36,7 +34,6 @@ import torch from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from rlhf_utils import stateless_init_process_group from transformers import AutoModelForCausalLM import vllm @@ -52,10 +49,45 @@ NCCLWeightTransferEngine, ) from vllm.utils.network_utils import get_ip, get_open_port +from vllm.v1.executor import Executor MODEL_NAME = "facebook/opt-125m" +class MyLLM(vllm.AsyncLLMEngine): + """Configure the vLLM worker for Ray placement group execution.""" + + def __init__(self, **kwargs): + os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1" + engine_args = vllm.AsyncEngineArgs(**kwargs) + vllm_config = engine_args.create_engine_config() + executor_class = Executor.get_class(vllm_config) + super().__init__( + vllm_config=vllm_config, + executor_class=executor_class, + log_requests=engine_args.enable_log_requests, + log_stats=not engine_args.disable_log_stats, + ) + + async def generate_with_retry( + self, prompt: str, sampling_params: vllm.SamplingParams + ) -> vllm.RequestOutput: + finish_reason = "abort" + while finish_reason == "abort": + async for request_output in self.generate( + prompt, sampling_params, request_id=str(uuid.uuid4()) + ): + output = request_output + finish_reason = output.outputs[0].finish_reason + if finish_reason == "abort": + print( + f"ABORT, prompt: [{prompt}], " + f"generated text: [{output.outputs[0].text}]" + ) + prompt += output.outputs[0].text + return output + + @ray.remote(num_gpus=1) class TrainModel: """Ray actor that wraps the training model on a dedicated GPU.""" @@ -86,83 +118,27 @@ def get_weight_metadata(self): def init_weight_transfer_group(self, 
master_address, master_port, world_size): """Initialize the NCCL process group for weight transfer.""" - self.model_update_group = stateless_init_process_group( - master_address, master_port, 0, world_size, self.device + self.model_update_group = NCCLWeightTransferEngine.trainer_init( + dict( + master_address=master_address, + master_port=master_port, + world_size=world_size, + ), + device=self.device, ) def broadcast_weights(self, packed: bool = True): """Broadcast weights to the inference engine.""" - NCCLWeightTransferEngine.trainer_broadcast_weights( + NCCLWeightTransferEngine.trainer_send_weights( iterator=self.model.named_parameters(), group=self.model_update_group, packed=packed, ) -class MyLLM: - """Simple wrapper over AsyncLLM for supporting async RL.""" - - def __init__(self, **kwargs): - os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1" - self.engine = vllm.AsyncLLMEngine.from_engine_args( - vllm.AsyncEngineArgs(**kwargs) - ) - self.generation_paused_event = asyncio.Event() - - async def generate( - self, prompt: str, sampling_params: vllm.SamplingParams - ) -> vllm.RequestOutput: - async for request_output in self.engine.generate( - prompt, sampling_params, request_id=str(uuid.uuid4()) - ): - final_output = request_output - return final_output - - async def generate_with_retry( - self, prompt: str, sampling_params: vllm.SamplingParams - ) -> vllm.RequestOutput: - finish_reason = "abort" - while finish_reason == "abort": - await self._wait_for_generation_to_resume() - output = await self.generate(prompt, sampling_params) - finish_reason = output.outputs[0].finish_reason - if finish_reason == "abort": - print(f"REQ ABORTED, prompt: {prompt}, text: {output.outputs[0].text}") - prompt += output.outputs[0].text - return output - - async def abort_generation(self) -> None: - self.generation_paused_event.set() - return await self.engine.pause_generation(wait_for_inflight_requests=False) - - async def resume_generation(self) -> None: - await 
self.engine.resume_generation() - self.generation_paused_event.clear() - - async def collective_rpc(self, method: str, args: tuple = ()): - return await self.engine.collective_rpc(method, args=args) - - async def _wait_for_generation_to_resume(self) -> None: - """Waits for generation to be resumed, intended for in-flight weight updates - and partial rollouts.""" - while self.generation_paused_event.is_set(): - await asyncio.sleep(0.5) - - async def init_weight_transfer(self, request: WeightTransferInitRequest) -> None: - print("reached init weight transfer") - return await self.engine.init_weight_transfer(request) - - async def update_weights(self, request: WeightUpdateRequest) -> None: - return await self.engine.update_weights(request) - - async def finalize_weight_update(self) -> None: - return await self.engine.finalize_weight_update() - - # Initialize Ray and set the visible devices. The vLLM engine will # be placed on GPUs 1 and 2. -os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" -ray.init(runtime_env={"excludes": [".git/objects/pack/"]}) +ray.init() # Launch the training model actor. Ray's resource scheduler will allocate # 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. @@ -193,6 +169,7 @@ async def finalize_weight_update(self) -> None: enforce_eager=True, tensor_parallel_size=2, distributed_executor_backend="ray", + load_format="dummy", weight_transfer_config=WeightTransferConfig(backend="nccl"), ) @@ -244,13 +221,13 @@ async def finalize_weight_update(self) -> None: finished, pending = ray.wait(generation_futures, num_returns=1) -# Abort generation in preparation for weight sync -ray.get(llm.abort_generation.remote()) +# Pause generation in preparation for weight sync +ray.get(llm.pause_generation.remote(wait_for_inflight_requests=False)) # Simulate a training step by zeroing out all model weights. 
# In a real RLHF training loop the weights would be updated using the gradient # from an RL objective such as PPO on a reward model. -ray.get(train_model.zero_weights.remote()) +# ray.get(train_model.zero_weights.remote()) # Synchronize the updated weights to the inference engine using batched API. # Collect all weight metadata from the training actor @@ -281,15 +258,30 @@ async def finalize_weight_update(self) -> None: # Resume generation since weight sync is complete ray.get(llm.resume_generation.remote()) -# Get all outputs -outputs = ray.get(finished) + ray.get(pending) +# Get outputs separately - finished completed before pause, pending were paused/resumed +finished_outputs = ray.get(finished) +pending_outputs = ray.get(pending) + +# Requests that finished before the pause: all generation used original weights +print("-" * 50) +print("Requests that completed BEFORE weight change:") +print("-" * 50) +for output in finished_outputs: + print(f"Prompt: {output.prompt!r}") + print(f"Generated (with original weights): {output.outputs[0].text!r}") + print("-" * 50) -# We expect the first output to be normal generation. 
-# The other outputs should have generated regular results midway -# and then have garbage tokens because we zero'd out the weights +# Requests that were paused mid-generation: some text before, some after weight change +print("Requests that were PAUSED and RESUMED after weight change:") print("-" * 50) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") +for output in pending_outputs: + # Find the original prompt by checking which one this output started with + original_prompt = next(p for p in prompts if output.prompt.startswith(p)) + # output.prompt contains original prompt + text generated before pause + # output.outputs[0].text is what was generated after resuming with new weights + text_before_pause = output.prompt[len(original_prompt) :] + text_after_pause = output.outputs[0].text + print(f"Original prompt: {original_prompt!r}") + print(f"Generated before weight change: {text_before_pause!r}") + print(f"Generated after weight change: {text_after_pause!r}") print("-" * 50) diff --git a/examples/online_serving/rlhf_http.py b/examples/online_serving/rlhf_http.py index 77908d89fe72..cdfb97759593 100644 --- a/examples/online_serving/rlhf_http.py +++ b/examples/online_serving/rlhf_http.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Demonstrates reinforcement learning from human feedback (RLHF) using vLLM -via HTTP API, with new weight syncing APIs. +via HTTP API, with native weight syncing APIs. Unlike rlhf.py which creates a vLLM instance programmatically, this script assumes you have already started a vLLM server using `vllm serve`. 
It uses: @@ -187,8 +187,13 @@ def main(): init_thread.start() # Initialize NCCL process group on trainer side - model_update_group = NCCLWeightTransferEngine.stateless_init_process_group( - master_address, master_port, 0, world_size, torch.device(device) + model_update_group = NCCLWeightTransferEngine.trainer_init( + dict( + master_address=master_address, + master_port=master_port, + world_size=world_size, + ), + device=torch.device(device), ) # Wait for init_weight_transfer to complete @@ -217,7 +222,7 @@ def main(): # Broadcast all weights from trainer to vLLM workers print("Broadcasting weights via NCCL...") - NCCLWeightTransferEngine.trainer_broadcast_weights( + NCCLWeightTransferEngine.trainer_send_weights( iterator=train_model.named_parameters(), group=model_update_group, packed=True, diff --git a/tests/distributed/test_packed_tensor.py b/tests/distributed/test_packed_tensor.py index 02a787e04c8c..b67ee4fef5b4 100644 --- a/tests/distributed/test_packed_tensor.py +++ b/tests/distributed/test_packed_tensor.py @@ -11,6 +11,10 @@ from vllm import envs from vllm.distributed.weight_transfer.nccl_engine import NCCLUpdateInfo +from vllm.distributed.weight_transfer.packed_tensor import ( + packed_broadcast_consumer, + packed_broadcast_producer, +) class MockCommunicationGroup: @@ -96,17 +100,13 @@ class TestPackedBroadcastProducer: def test_producer_broadcasts_tensors(self, monkeypatch): """Test that producer broadcasts all tensors.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - packed_broadcast_producer, - ) - params = create_mock_model_params() params_cuda = [(name, tensor.cuda()) for name, tensor in params] mock_group = MockCommunicationGroup() # Use a small target size to force multiple batches - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 500) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 500) packed_broadcast_producer( iterator=iter(params_cuda), group=mock_group, @@ -120,10 +120,6 @@ def 
test_producer_broadcasts_tensors(self, monkeypatch): def test_producer_single_large_tensor(self, monkeypatch): """Test with a single tensor larger than target size.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - packed_broadcast_producer, - ) - # Create a large tensor large_tensor = torch.randn(1000, 1000, dtype=torch.float32).cuda() params = [("large_weight", large_tensor)] @@ -131,7 +127,7 @@ def test_producer_single_large_tensor(self, monkeypatch): mock_group = MockCommunicationGroup() # Small target size to force the tensor to exceed it - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 100) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 100) packed_broadcast_producer( iterator=iter(params), group=mock_group, @@ -150,10 +146,6 @@ def test_producer_single_large_tensor(self, monkeypatch): def test_producer_multiple_batches(self, monkeypatch): """Test that tensors are properly batched when exceeding target size.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - packed_broadcast_producer, - ) - # Create many small tensors params = [ (f"weight_{i}", torch.randn(10, 10, dtype=torch.float32).cuda()) @@ -163,7 +155,7 @@ def test_producer_multiple_batches(self, monkeypatch): mock_group = MockCommunicationGroup() # Small target size to force multiple batches - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 2000) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 2000) packed_broadcast_producer( iterator=iter(params), group=mock_group, @@ -181,13 +173,9 @@ def test_producer_multiple_batches(self, monkeypatch): def test_producer_empty_iterator(self, monkeypatch): """Test producer handles empty iterator gracefully.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - packed_broadcast_producer, - ) - mock_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 1000) + monkeypatch.setattr(envs, 
"VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 1000) packed_broadcast_producer( iterator=iter([]), group=mock_group, @@ -208,18 +196,13 @@ class TestPackedBroadcastConsumer: def test_consumer_receives_tensors(self, monkeypatch): """Test that consumer receives and unpacks tensors.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - packed_broadcast_consumer, - packed_broadcast_producer, - ) - params = create_mock_model_params() params_cuda = [(name, tensor.cuda()) for name, tensor in params] # First, run producer to get the broadcasted tensors producer_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 2000) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 2000) packed_broadcast_producer( iterator=iter(params_cuda), group=producer_group, @@ -269,17 +252,12 @@ class TestPackedBroadcastRoundtrip: @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) def test_roundtrip_different_dtypes(self, dtype, monkeypatch): """Test roundtrip with different data types.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - packed_broadcast_consumer, - packed_broadcast_producer, - ) - params = create_mock_model_params(num_layers=2, dtype=dtype) params_cuda = [(name, tensor.cuda()) for name, tensor in params] producer_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 1000) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 1000) packed_broadcast_producer( iterator=iter(params_cuda), group=producer_group, @@ -314,11 +292,6 @@ def post_unpack_func(tensor_list): def test_roundtrip_mixed_dtypes(self, monkeypatch): """Test roundtrip with mixed data types.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - packed_broadcast_consumer, - packed_broadcast_producer, - ) - # Create params with mixed dtypes params = [ ("layer1.weight", torch.randn(10, 20, dtype=torch.float32).cuda()), @@ -328,7 +301,7 @@ def 
test_roundtrip_mixed_dtypes(self, monkeypatch): producer_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", 500) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 500) packed_broadcast_producer( iterator=iter(params), group=producer_group, @@ -365,17 +338,12 @@ def post_unpack_func(tensor_list): @pytest.mark.parametrize("target_size", [100, 1000, 10000, 100000]) def test_roundtrip_different_batch_sizes(self, target_size, monkeypatch): """Test roundtrip with different target batch sizes.""" - from vllm.distributed.weight_transfer.packed_tensor import ( - packed_broadcast_consumer, - packed_broadcast_producer, - ) - params = create_mock_model_params(num_layers=5) params_cuda = [(name, tensor.cuda()) for name, tensor in params] producer_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE", target_size) + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", target_size) packed_broadcast_producer( iterator=iter(params_cuda), group=producer_group, diff --git a/tests/distributed/test_weight_transfer.py b/tests/distributed/test_weight_transfer.py index b98882cf4934..8c16aebe0c10 100644 --- a/tests/distributed/test_weight_transfer.py +++ b/tests/distributed/test_weight_transfer.py @@ -6,7 +6,6 @@ Integration test for NCCL weight transfer between processes using Ray. 
""" -from dataclasses import dataclass from unittest.mock import MagicMock import pytest @@ -16,10 +15,6 @@ from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer import WeightTransferEngineFactory -from vllm.distributed.weight_transfer.base import ( - BackendInitInfo, - WeightTransferEngine, -) from vllm.distributed.weight_transfer.nccl_engine import ( NCCLInitInfo, NCCLUpdateInfo, @@ -167,33 +162,6 @@ def test_create_engine_invalid_backend(self): with pytest.raises(ValueError, match="Invalid weight transfer backend"): WeightTransferEngineFactory.create_engine(config, parallel_config) - def test_register_custom_engine(self): - """Test registering a custom engine.""" - - @dataclass - class CustomInitInfo(BackendInitInfo): - pass - - class CustomEngine(WeightTransferEngine): - init_info_cls = CustomInitInfo - update_info_cls = NCCLUpdateInfo # Reuse for simplicity - - def init_transfer(self, init_info): - pass - - def receive_weights(self, update_info, load_weights): - pass - - def shutdown(self): - pass - - # Register custom engine - WeightTransferEngineFactory.register_engine("custom_test", CustomEngine) - assert WeightTransferEngineFactory.is_registered("custom_test") - - # Clean up - WeightTransferEngineFactory.unregister_engine("custom_test") - def test_register_duplicate_raises(self): """Test registering duplicate engine name raises.""" with pytest.raises(ValueError, match="already registered"): diff --git a/vllm/distributed/weight_transfer/factory.py b/vllm/distributed/weight_transfer/factory.py index 5f24b0062141..7235e30d1af6 100644 --- a/vllm/distributed/weight_transfer/factory.py +++ b/vllm/distributed/weight_transfer/factory.py @@ -89,68 +89,20 @@ def create_engine( ValueError: If the backend is not registered """ backend = config.backend - engine_cls = cls.get_engine_class(backend) - - logger.info( - "Creating weight transfer engine: %s", - engine_cls.__name__, - 
) - - return engine_cls(config, parallel_config) - - @classmethod - def get_engine_class(cls, backend: str) -> type[WeightTransferEngine]: - """Get a registered engine class by name. - - Args: - backend: Name of the registered engine - - Returns: - The engine class - - Raises: - ValueError: If the engine is not registered - """ if backend not in cls._registry: available = list(cls._registry.keys()) raise ValueError( f"Invalid weight transfer backend: {backend}. " f"Available engines: {available}" ) - return cls._registry[backend]() - - @classmethod - def list_engines(cls) -> list[str]: - """List all registered engine names. - - Returns: - List of registered engine names - """ - return list(cls._registry.keys()) - - @classmethod - def unregister_engine(cls, name: str) -> None: - """Unregister an engine by name. - - Args: - name: Name of the engine to unregister + engine_cls = cls._registry[backend]() - Raises: - KeyError: If the engine is not registered - """ - del cls._registry[name] - - @classmethod - def is_registered(cls, name: str) -> bool: - """Check if an engine is registered. - - Args: - name: Name of the engine to check + logger.info( + "Creating weight transfer engine: %s", + engine_cls.__name__, + ) - Returns: - True if the engine is registered - """ - return name in cls._registry + return engine_cls(config, parallel_config) # Register built-in weight transfer engines here. 
diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index c317285544f1..74bedc4058a4 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -101,12 +101,14 @@ def init_transfer(self, init_info: NCCLInitInfo) -> None: worker_rank = dp_rank * world_size_per_dp + tp_rank rank = worker_rank + init_info.rank_offset # Create stateless process group - self.model_update_group = NCCLWeightTransferEngine.stateless_init_process_group( - init_info.master_address, - init_info.master_port, - rank, - init_info.world_size, - torch.cuda.current_device(), + self.model_update_group = ( + NCCLWeightTransferEngine._stateless_init_process_group( + init_info.master_address, + init_info.master_port, + rank, + init_info.world_size, + torch.cuda.current_device(), + ) ) def receive_weights( @@ -171,7 +173,7 @@ def shutdown(self) -> None: self.model_update_group = None @staticmethod - def trainer_broadcast_weights( + def trainer_send_weights( iterator: Iterator[tuple[str, torch.Tensor]], group: Any, src: int = 0, @@ -191,13 +193,15 @@ def trainer_broadcast_weights( packed: Whether to use packed tensor broadcasting for efficiency. When True, multiple tensors are batched together before broadcasting to reduce NCCL communication overhead. + stream: CUDA stream to use for broadcasting if packed is False. + If packed is True, new streams will be created for each buffer. Example: >>> from vllm.distributed.weight_transfer.nccl_engine import ( ... NCCLWeightTransferEngine, ... ) >>> param_iter = ((n, p) for n, p in model.named_parameters()) - >>> NCCLWeightTransferEngine.trainer_broadcast_weights( + >>> NCCLWeightTransferEngine.trainer_send_weights( ... param_iter, group, packed=True ... 
) """ @@ -226,7 +230,64 @@ def trainer_broadcast_weights( ) @staticmethod - def stateless_init_process_group( + def trainer_init( + init_info: NCCLInitInfo | dict, + device: torch.device | int | str | None = None, + ) -> "PyNcclCommunicator": + """ + Initialize NCCL process group for trainer-side weight transfer. + + The trainer is always rank 0 in the process group. + + Args: + init_info: Either an NCCLInitInfo object or a dict with keys: + - master_address: str + - master_port: int + - world_size: int + Optionally can include 'device' if not passed separately. + device: The CUDA device to use. If not provided, must be in init_info. + + Returns: + PyNcclCommunicator for weight transfer. + + Example: + >>> from vllm.distributed.weight_transfer.nccl_engine import ( + ... NCCLWeightTransferEngine, + ... ) + >>> group = NCCLWeightTransferEngine.trainer_init( + ... dict( + ... master_address=master_address, + ... master_port=master_port, + ... world_size=world_size, + ... ), + ... device=torch.device("cuda:0"), + ... 
) + """ + if isinstance(init_info, dict): + master_address = init_info["master_address"] + master_port = init_info["master_port"] + world_size = init_info["world_size"] + if device is None: + device = init_info.get("device") + else: + # NCCLInitInfo object + master_address = init_info.master_address + master_port = init_info.master_port + world_size = init_info.world_size + # NCCLInitInfo doesn't have device, so device param is required + + if device is None: + raise ValueError( + "device must be provided either in init_info or as a separate parameter" + ) + + # Trainer is always rank 0 + return NCCLWeightTransferEngine._stateless_init_process_group( + master_address, master_port, 0, world_size, device + ) + + @staticmethod + def _stateless_init_process_group( master_address, master_port, rank, world_size, device ): """ diff --git a/vllm/distributed/weight_transfer/packed_tensor.py b/vllm/distributed/weight_transfer/packed_tensor.py index 7366a939bf91..3fe25210a67f 100644 --- a/vllm/distributed/weight_transfer/packed_tensor.py +++ b/vllm/distributed/weight_transfer/packed_tensor.py @@ -27,7 +27,7 @@ def packed_broadcast_producer( packing, should return a tensor """ - target_packed_tensor_size = envs.VLLM_PACKED_TENSOR_BUFFER_SIZE + target_packed_tensor_size = envs.VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES num_buffers = envs.VLLM_PACKED_TENSOR_NUM_BUFFERS streams = [torch.cuda.Stream() for _ in range(num_buffers)] @@ -40,8 +40,6 @@ def packed_broadcast_producer( ] while True: - # Move to the next buffer - buffer_idx = (buffer_idx + 1) % num_buffers # Synchronize the current stream streams[buffer_idx].synchronize() # Start tasks for the new buffer in a new stream @@ -63,6 +61,8 @@ def packed_broadcast_producer( packing_tensor_list[buffer_idx], dim=0 ) group.broadcast(packed_tensors[buffer_idx], src=src) + # Move to the next buffer + buffer_idx = (buffer_idx + 1) % num_buffers except StopIteration: # Do the last broadcast if there are remaining tensors if 
len(packing_tensor_list[buffer_idx]) > 0: @@ -92,66 +92,64 @@ def packed_broadcast_consumer( def unpack_tensor( packed_tensor: torch.Tensor, - meta_data_list: list[tuple[str, list[int], torch.dtype, int, int]], + names: list[str], + shapes: list[list[int]], + dtypes: list[torch.dtype], + tensor_sizes: list[int], ) -> list[tuple[str, torch.Tensor]]: """Unpack a single tensor into a list of tensors. Args: packed_tensor: The packed torch.uint8 tensor to unpack - meta_data_list: List[(name, shape, dtype, offset, tensor_size)] + names: List of tensor names + shapes: List of tensor shapes + dtypes: List of tensor dtypes + tensor_sizes: List of tensor sizes in bytes Returns: unpacked List[(name, tensor)] """ - # Perform batched split with torch.split_with_sizes - packed_tensor_sizes = [meta[4] for meta in meta_data_list] - unpacked_tensors = packed_tensor.split_with_sizes(packed_tensor_sizes) + unpacked_tensors = packed_tensor.split_with_sizes(tensor_sizes) unpacked_list = [ - ( - meta_data_list[i][0], - tensor.view(meta_data_list[i][2]).view(*meta_data_list[i][1]), + (name, tensor.view(dtype).view(*shape)) + for name, shape, dtype, tensor in zip( + names, shapes, dtypes, unpacked_tensors ) - for i, tensor in enumerate(unpacked_tensors) ] return unpacked_list - target_packed_tensor_size = envs.VLLM_PACKED_TENSOR_BUFFER_SIZE + target_packed_tensor_size = envs.VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES num_buffers = envs.VLLM_PACKED_TENSOR_NUM_BUFFERS streams = [torch.cuda.Stream() for _ in range(num_buffers)] buffer_idx = 0 - packing_tensor_meta_data: list[ - list[tuple[str, list[int], torch.dtype, int, int]] - ] = [[] for _ in range(num_buffers)] + packing_tensor_meta_data: list[list[tuple[str, list[int], torch.dtype, int]]] = [ + [] for _ in range(num_buffers) + ] packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)] - offsets: list[int] = [0 for _ in range(num_buffers)] packed_tensors: list[torch.Tensor] = [ torch.empty(0, dtype=torch.uint8, device="cuda") for _ 
in range(num_buffers) ] while True: - # Move to the next buffer - buffer_idx = (buffer_idx + 1) % num_buffers # Synchronize the current stream streams[buffer_idx].synchronize() with torch.cuda.stream(streams[buffer_idx]): # Initialize the packing tensor meta data packing_tensor_meta_data[buffer_idx] = [] packing_tensor_sizes[buffer_idx] = 0 - offsets[buffer_idx] = 0 try: # Form a packed tensor while True: name, (shape, dtype) = next(iterator) tensor_size = math.prod(shape) * dtype.itemsize packing_tensor_meta_data[buffer_idx].append( - (name, shape, dtype, offsets[buffer_idx], tensor_size) + (name, shape, dtype, tensor_size) ) packing_tensor_sizes[buffer_idx] += tensor_size - offsets[buffer_idx] += tensor_size if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size: break # Create a packed tensor and broadcast it @@ -160,12 +158,20 @@ def unpack_tensor( ) group.broadcast(packed_tensors[buffer_idx], src=src) # Load the packed tensor into the model + names, shapes, dtypes, tensor_sizes = zip( + *packing_tensor_meta_data[buffer_idx] + ) post_unpack_func( unpack_tensor( packed_tensors[buffer_idx], - packing_tensor_meta_data[buffer_idx], + list(names), + list(shapes), + list(dtypes), + list(tensor_sizes), ) ) + # Move to the next buffer + buffer_idx = (buffer_idx + 1) % num_buffers except StopIteration: # Do the last broadcast if there are remaining tensors if len(packing_tensor_meta_data[buffer_idx]) > 0: @@ -177,10 +183,16 @@ def unpack_tensor( ) group.broadcast(packed_tensors[buffer_idx], src=src) # Load the packed tensor into the model + names, shapes, dtypes, tensor_sizes = zip( + *packing_tensor_meta_data[buffer_idx] + ) post_unpack_func( unpack_tensor( packed_tensors[buffer_idx], - packing_tensor_meta_data[buffer_idx], + list(names), + list(shapes), + list(dtypes), + list(tensor_sizes), ) ) break diff --git a/vllm/envs.py b/vllm/envs.py index 948525561063..12058bdce5ae 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -258,7 +258,7 @@ 
VLLM_LOG_MODEL_INSPECTION: bool = False VLLM_DEBUG_MFU_METRICS: bool = False VLLM_PACKED_TENSOR_NUM_BUFFERS: int = 2 - VLLM_PACKED_TENSOR_BUFFER_SIZE: int = 1024 * 1024 * 1024 # 1GB + VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES: int = 1024 * 1024 * 1024 # 1GB def get_default_cache_root(): @@ -1681,8 +1681,8 @@ def _get_or_set_default() -> str: os.getenv("VLLM_PACKED_TENSOR_NUM_BUFFERS", "2") ), # Size in bytes for each packed tensor buffer (default 1GB) - "VLLM_PACKED_TENSOR_BUFFER_SIZE": lambda: int( - os.getenv("VLLM_PACKED_TENSOR_BUFFER_SIZE", str(1024 * 1024 * 1024)) + "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES": lambda: int( + os.getenv("VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", str(1024 * 1024 * 1024)) ), } From 6fb97778d2ea6a1000c8146f5bedb4f4ae808175 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Tue, 27 Jan 2026 14:05:36 -0800 Subject: [PATCH 24/36] x Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf_async_new_apis.py | 45 ++++++++++--------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index 7668fcc38db2..d7a412ea8624 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -34,7 +34,7 @@ import torch from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer import vllm from vllm import SamplingParams @@ -70,21 +70,23 @@ def __init__(self, **kwargs): ) async def generate_with_retry( - self, prompt: str, sampling_params: vllm.SamplingParams + self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams ) -> vllm.RequestOutput: finish_reason = "abort" while finish_reason == "abort": async for request_output in 
self.generate( - prompt, sampling_params, request_id=str(uuid.uuid4()) + {"prompt_token_ids": prompt_token_ids}, + sampling_params, + request_id=str(uuid.uuid4()), ): output = request_output finish_reason = output.outputs[0].finish_reason if finish_reason == "abort": print( - f"ABORT, prompt: [{prompt}], " - f"generated text: [{output.outputs[0].text}]" + f"ABORT, prompt_token_ids: {prompt_token_ids}, " + f"generated token_ids: {list(output.outputs[0].token_ids)}" ) - prompt += output.outputs[0].text + prompt_token_ids = prompt_token_ids + list(output.outputs[0].token_ids) return output @@ -100,11 +102,6 @@ def __init__(self, model_name: str): ) self.model.to(self.device) - def zero_weights(self): - """Zero out all model weights (simulates training step).""" - for name, p in self.model.named_parameters(): - p.data.zero_() - def get_weight_metadata(self): """Return weight names, dtypes, and shapes for weight transfer.""" names = [] @@ -181,6 +178,12 @@ def broadcast_weights(self, packed: bool = True): "The future of AI is", ] +# Tokenize prompts to token IDs +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +prompt_token_ids_list = [ + tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts +] + sampling_params = [ SamplingParams(temperature=0, max_tokens=2), SamplingParams(temperature=0, max_tokens=32), @@ -215,8 +218,8 @@ def broadcast_weights(self, packed: bool = True): generation_futures = [ - llm.generate_with_retry.remote(prompt, params) - for prompt, params in zip(prompts, sampling_params) + llm.generate_with_retry.remote(prompt_token_ids, params) + for prompt_token_ids, params in zip(prompt_token_ids_list, sampling_params) ] finished, pending = ray.wait(generation_futures, num_returns=1) @@ -224,11 +227,6 @@ def broadcast_weights(self, packed: bool = True): # Pause generation in preparation for weight sync ray.get(llm.pause_generation.remote(wait_for_inflight_requests=False)) -# Simulate a training step by zeroing out all model weights. 
-# In a real RLHF training loop the weights would be updated using the gradient -# from an RL objective such as PPO on a reward model. -# ray.get(train_model.zero_weights.remote()) - # Synchronize the updated weights to the inference engine using batched API. # Collect all weight metadata from the training actor names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote()) @@ -267,7 +265,8 @@ def broadcast_weights(self, packed: bool = True): print("Requests that completed BEFORE weight change:") print("-" * 50) for output in finished_outputs: - print(f"Prompt: {output.prompt!r}") + prompt_text = tokenizer.decode(output.prompt_token_ids) + print(f"Prompt: {prompt_text!r}") print(f"Generated (with original weights): {output.outputs[0].text!r}") print("-" * 50) @@ -275,11 +274,13 @@ def broadcast_weights(self, packed: bool = True): print("Requests that were PAUSED and RESUMED after weight change:") print("-" * 50) for output in pending_outputs: + # Decode the full prompt token IDs (original + generated before pause) + full_prompt_text = tokenizer.decode(output.prompt_token_ids) # Find the original prompt by checking which one this output started with - original_prompt = next(p for p in prompts if output.prompt.startswith(p)) - # output.prompt contains original prompt + text generated before pause + original_prompt = next(p for p in prompts if full_prompt_text.startswith(p)) + # output.prompt_token_ids contains original prompt + tokens generated before pause # output.outputs[0].text is what was generated after resuming with new weights - text_before_pause = output.prompt[len(original_prompt) :] + text_before_pause = full_prompt_text[len(original_prompt) :] text_after_pause = output.outputs[0].text print(f"Original prompt: {original_prompt!r}") print(f"Generated before weight change: {text_before_pause!r}") From baf5bcf97173ea13063a85a3212f21b4bfcc4903 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Tue, 27 Jan 2026 14:19:38 -0800 Subject: [PATCH 
25/36] x Signed-off-by: ahao-anyscale --- tests/distributed/test_packed_tensor.py | 56 +++++++++++++++++++ .../weight_transfer/packed_tensor.py | 11 +++- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_packed_tensor.py b/tests/distributed/test_packed_tensor.py index b67ee4fef5b4..d574e404f5e4 100644 --- a/tests/distributed/test_packed_tensor.py +++ b/tests/distributed/test_packed_tensor.py @@ -376,3 +376,59 @@ def post_unpack_func(tensor_list): assert torch.allclose( unpacked_tensors[name], original_tensor, rtol=1e-5, atol=1e-7 ) + + def test_roundtrip_non_contiguous_tensors(self, monkeypatch): + """Test roundtrip with non-contiguous tensors from the trainer.""" + # Create non-contiguous tensors (simulating trainer outputs) + # Transposed tensors are non-contiguous + weight1 = torch.randn(20, 10, dtype=torch.float32).cuda().T + # Sliced tensors with step are non-contiguous + weight2 = torch.randn(40, 30, dtype=torch.float16).cuda()[::2, ::2] + # Permuted tensors are non-contiguous + weight3 = torch.randn(5, 10, 15, dtype=torch.bfloat16).cuda().permute(2, 0, 1) + + params = [ + ("layer1.weight", weight1), + ("layer2.weight", weight2), + ("layer3.weight", weight3), + ] + + # Verify tensors are indeed non-contiguous + for name, tensor in params: + assert not tensor.is_contiguous(), f"{name} should be non-contiguous" + + producer_group = MockCommunicationGroup() + + monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 500) + packed_broadcast_producer( + iterator=iter(params), + group=producer_group, + src=0, + post_iter_func=lambda x: x[1], + ) + + consumer_group = MockConsumerCommunicationGroup( + producer_group.broadcasted_tensors + ) + + state_dict_info = create_state_dict_info(params) + unpacked_tensors = {} + + def post_unpack_func(tensor_list): + for name, tensor in tensor_list: + unpacked_tensors[name] = tensor.clone() + + packed_broadcast_consumer( + iterator=iter(state_dict_info.items()), + 
group=consumer_group, + src=0, + post_unpack_func=post_unpack_func, + ) + + # Verify all non-contiguous params roundtrip correctly + for name, original_tensor in params: + assert name in unpacked_tensors + unpacked = unpacked_tensors[name] + assert unpacked.shape == original_tensor.shape + assert unpacked.dtype == original_tensor.dtype + assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6) diff --git a/vllm/distributed/weight_transfer/packed_tensor.py b/vllm/distributed/weight_transfer/packed_tensor.py index 3fe25210a67f..3c5a06f600b6 100644 --- a/vllm/distributed/weight_transfer/packed_tensor.py +++ b/vllm/distributed/weight_transfer/packed_tensor.py @@ -51,7 +51,12 @@ def packed_broadcast_producer( # Pack the tensors while True: # Apply post processing and convert to linearized uint8 tensor - tensor = post_iter_func(next(iterator)).view(torch.uint8).view(-1) + tensor = ( + post_iter_func(next(iterator)) + .contiguous() + .view(torch.uint8) + .view(-1) + ) packing_tensor_list[buffer_idx].append(tensor) packing_tensor_sizes[buffer_idx] += tensor.numel() if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size: @@ -109,10 +114,10 @@ def unpack_tensor( Returns: unpacked List[(name, tensor)] """ - unpacked_tensors = packed_tensor.split_with_sizes(tensor_sizes) + unpacked_tensors = packed_tensor.split(tensor_sizes) unpacked_list = [ - (name, tensor.view(dtype).view(*shape)) + (name, tensor.contiguous().view(dtype).view(*shape)) for name, shape, dtype, tensor in zip( names, shapes, dtypes, unpacked_tensors ) From 71364e12d75f825509da2907e51e3b1d20ced015 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Tue, 27 Jan 2026 16:19:02 -0800 Subject: [PATCH 26/36] x Signed-off-by: ahao-anyscale --- .buildkite/test-amd.yaml | 8 ++++++++ .buildkite/test-pipeline.yaml | 8 ++++++++ .buildkite/test_areas/distributed.yaml | 7 +++++++ .../new_weight_syncing/rlhf.py | 6 +----- .../new_weight_syncing/rlhf_async_new_apis.py | 6 +----- 
examples/online_serving/rlhf_http.py | 1 - vllm/distributed/weight_transfer/nccl_engine.py | 17 +++-------------- 7 files changed, 28 insertions(+), 25 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 05d4c5f468c8..17dc284ea8f7 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -227,6 +227,8 @@ steps: - tests/distributed/test_pynccl - tests/distributed/test_events - tests/compile/fullgraph/test_basic_correctness.py + - examples/offline_inference/rlhf.py + - examples/offline_inference/rlhf_colocate.py - examples/offline_inference/new_weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed @@ -263,6 +265,12 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests + # OLD rlhf examples + - pushd ../examples/offline_inference + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - popd + # NEW rlhf examples - pushd ../examples/offline_inference/new_weight_syncing - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 5d1bc65a3525..d2547d329937 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -200,6 +200,8 @@ steps: - tests/distributed/test_pynccl - tests/distributed/test_events - tests/compile/fullgraph/test_basic_correctness.py + - examples/offline_inference/rlhf.py + - examples/offline_inference/rlhf_colocate.py - examples/offline_inference/new_weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed @@ -235,6 +237,12 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we 
have multiple distributed example tests + # OLD rlhf examples + - pushd ../examples/offline_inference + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - popd + # NEW rlhf examples - pushd ../examples/offline_inference/new_weight_syncing - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index 5e047371f6d8..c616bcc8c734 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -61,6 +61,8 @@ steps: - tests/distributed/test_pynccl - tests/distributed/test_events - tests/compile/fullgraph/test_basic_correctness.py + - examples/offline_inference/rlhf.py + - examples/offline_inference/rlhf_colocate.py - examples/offline_inference/new_weight_syncing/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed @@ -96,6 +98,11 @@ steps: - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests + # OLD rlhf examples + - cd ../examples/offline_inference + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + # NEW rlhf examples - cd ../examples/offline_inference/new_weight_syncing - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py index 84548fbd9965..6e179cb7f765 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -57,12 +57,9 @@ class TrainModel: """Ray actor that wraps the training model on a dedicated 
GPU.""" def __init__(self, model_name: str): - self.device = torch.device("cuda:0") - torch.cuda.set_device(self.device) self.model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16 - ) - self.model.to(self.device) + ).to("cuda:0") def get_weight_metadata(self): """Return weight names, dtypes, and shapes for weight transfer.""" @@ -83,7 +80,6 @@ def init_weight_transfer_group(self, master_address, master_port, world_size): master_port=master_port, world_size=world_size, ), - device=self.device, ) def broadcast_weights(self, packed: bool = True): diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index d7a412ea8624..10446b11ffcc 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -95,12 +95,9 @@ class TrainModel: """Ray actor that wraps the training model on a dedicated GPU.""" def __init__(self, model_name: str): - self.device = torch.device("cuda:0") - torch.cuda.set_device(self.device) self.model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16 - ) - self.model.to(self.device) + ).to("cuda:0") def get_weight_metadata(self): """Return weight names, dtypes, and shapes for weight transfer.""" @@ -121,7 +118,6 @@ def init_weight_transfer_group(self, master_address, master_port, world_size): master_port=master_port, world_size=world_size, ), - device=self.device, ) def broadcast_weights(self, packed: bool = True): diff --git a/examples/online_serving/rlhf_http.py b/examples/online_serving/rlhf_http.py index cdfb97759593..40496ed6e99a 100644 --- a/examples/online_serving/rlhf_http.py +++ b/examples/online_serving/rlhf_http.py @@ -193,7 +193,6 @@ def main(): master_port=master_port, world_size=world_size, ), - device=torch.device(device), ) # Wait for init_weight_transfer to complete diff --git 
a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index 74bedc4058a4..8140c9f66a43 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -232,20 +232,18 @@ def trainer_send_weights( @staticmethod def trainer_init( init_info: NCCLInitInfo | dict, - device: torch.device | int | str | None = None, ) -> "PyNcclCommunicator": """ Initialize NCCL process group for trainer-side weight transfer. - The trainer is always rank 0 in the process group. + The trainer is always rank 0 in the process group. Uses the current + CUDA device (torch.cuda.current_device()). Args: init_info: Either an NCCLInitInfo object or a dict with keys: - master_address: str - master_port: int - world_size: int - Optionally can include 'device' if not passed separately. - device: The CUDA device to use. If not provided, must be in init_info. Returns: PyNcclCommunicator for weight transfer. @@ -260,30 +258,21 @@ def trainer_init( ... master_port=master_port, ... world_size=world_size, ... ), - ... device=torch.device("cuda:0"), ... 
) """ if isinstance(init_info, dict): master_address = init_info["master_address"] master_port = init_info["master_port"] world_size = init_info["world_size"] - if device is None: - device = init_info.get("device") else: # NCCLInitInfo object master_address = init_info.master_address master_port = init_info.master_port world_size = init_info.world_size - # NCCLInitInfo doesn't have device, so device param is required - - if device is None: - raise ValueError( - "device must be provided either in init_info or as a separate parameter" - ) # Trainer is always rank 0 return NCCLWeightTransferEngine._stateless_init_process_group( - master_address, master_port, 0, world_size, device + master_address, master_port, 0, world_size, torch.cuda.current_device() ) @staticmethod From ee9a5b5a6ba6bd34f9839ee65311693b7a95447e Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Tue, 27 Jan 2026 18:06:09 -0800 Subject: [PATCH 27/36] x Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py index 6e179cb7f765..9f0bdd06bf34 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -60,6 +60,11 @@ def __init__(self, model_name: str): self.model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16 ).to("cuda:0") + self.port = get_open_port() + self.master_address = get_ip() + + def get_master_address_and_port(self): + return self.master_address, self.port def get_weight_metadata(self): """Return weight names, dtypes, and shapes for weight transfer.""" @@ -72,12 +77,12 @@ def get_weight_metadata(self): shapes.append(list(p.shape)) return names, dtype_names, shapes - def init_weight_transfer_group(self, master_address, master_port, world_size): + def init_weight_transfer_group(self, world_size): 
"""Initialize the NCCL process group for weight transfer.""" self.model_update_group = NCCLWeightTransferEngine.trainer_init( dict( - master_address=master_address, - master_port=master_port, + master_address=self.master_address, + master_port=self.port, world_size=world_size, ), ) @@ -151,8 +156,7 @@ def broadcast_weights(self, packed: bool = True): # Set up the communication channel between the training process and the # inference engine. -master_address = get_ip() -master_port = get_open_port() +master_address, master_port = ray.get(train_model.get_master_address_and_port.remote()) world_size = ray.get(llm.get_world_size.remote()) + 1 # +1 for the trainer inference_handle = llm.init_weight_transfer.remote( @@ -167,9 +171,7 @@ def broadcast_weights(self, packed: bool = True): ) # Initialize weight transfer group on both the training actor and inference engine -train_handle = train_model.init_weight_transfer_group.remote( - master_address, master_port, world_size -) +train_handle = train_model.init_weight_transfer_group.remote(world_size) ray.get([train_handle, inference_handle]) # Synchronize the updated weights to the inference engine using batched API. 
From 19ed12a52c41ce53a9bc0e50657229dccf971020 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Tue, 27 Jan 2026 18:14:22 -0800 Subject: [PATCH 28/36] x Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf_async_new_apis.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index 10446b11ffcc..7deef748774a 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -98,6 +98,11 @@ def __init__(self, model_name: str): self.model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16 ).to("cuda:0") + self.port = get_open_port() + self.master_address = get_ip() + + def get_master_address_and_port(self): + return self.master_address, self.port def get_weight_metadata(self): """Return weight names, dtypes, and shapes for weight transfer.""" @@ -110,12 +115,12 @@ def get_weight_metadata(self): shapes.append(list(p.shape)) return names, dtype_names, shapes - def init_weight_transfer_group(self, master_address, master_port, world_size): + def init_weight_transfer_group(self, world_size): """Initialize the NCCL process group for weight transfer.""" self.model_update_group = NCCLWeightTransferEngine.trainer_init( dict( - master_address=master_address, - master_port=master_port, + master_address=self.master_address, + master_port=self.port, world_size=world_size, ), ) @@ -189,8 +194,7 @@ def broadcast_weights(self, packed: bool = True): # Set up the communication channel between the training process and the # inference engine. 
-master_address = get_ip() -master_port = get_open_port() +master_address, master_port = ray.get(train_model.get_master_address_and_port.remote()) world_size = 3 # 1 trainer + 2 inference workers (tensor_parallel_size=2) inference_handle = llm.init_weight_transfer.remote( @@ -207,9 +211,7 @@ def broadcast_weights(self, packed: bool = True): ) # Initialize weight transfer group on both the training actor and inference engine -train_handle = train_model.init_weight_transfer_group.remote( - master_address, master_port, world_size -) +train_handle = train_model.init_weight_transfer_group.remote(world_size) ray.get([train_handle, inference_handle]) From a4b42392e59254768b7b64439fa368897aa38428 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Tue, 27 Jan 2026 21:59:09 -0800 Subject: [PATCH 29/36] x Signed-off-by: ahao-anyscale --- .buildkite/test_areas/distributed.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index f780e0aab232..b37c7e05e311 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -103,7 +103,7 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py # NEW rlhf examples - - cd ../examples/offline_inference/new_weight_syncing + - cd new_weight_syncing - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py From c7e89f7619b318b8928fddf34f35ef0ad827750d Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 29 Jan 2026 13:06:20 -0800 Subject: [PATCH 30/36] x Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf.py | 4 +- .../new_weight_syncing/rlhf_async_new_apis.py | 16 ++--- examples/online_serving/rlhf_http.py | 8 +-- tests/distributed/test_packed_tensor.py | 59 +++++++++++-------- tests/distributed/test_weight_transfer.py | 38 ++++++------ 
.../entrypoints/openai/test_openai_schema.py | 2 +- .../test_weight_transfer_llm.py | 45 +++++++------- vllm/distributed/weight_transfer/base.py | 12 ++-- .../weight_transfer/nccl_engine.py | 59 ++++++++++++------- .../weight_transfer/packed_tensor.py | 23 ++++++-- vllm/engine/protocol.py | 6 +- vllm/entrypoints/llm.py | 10 ++-- vllm/entrypoints/serve/rlhf/api_router.py | 10 ++-- vllm/envs.py | 10 ---- vllm/v1/engine/async_llm.py | 16 +++-- vllm/v1/worker/gpu_worker.py | 4 +- 16 files changed, 180 insertions(+), 142 deletions(-) diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py index 9f0bdd06bf34..7fe5c7a66d06 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -117,7 +117,7 @@ def broadcast_weights(self, packed: bool = True): # Launch the vLLM inference engine. The `enforce_eager` flag reduces # start-up latency. -# Note: Weight transfer APIs (init_weight_transfer, update_weights, +# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights, # finalize_weight_update) are now native to vLLM workers. 
llm = ray.remote( num_cpus=0, @@ -159,7 +159,7 @@ def broadcast_weights(self, packed: bool = True): master_address, master_port = ray.get(train_model.get_master_address_and_port.remote()) world_size = ray.get(llm.get_world_size.remote()) + 1 # +1 for the trainer -inference_handle = llm.init_weight_transfer.remote( +inference_handle = llm.init_weight_transfer_engine.remote( dict( init_info=dict( master_address=master_address, diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index 7deef748774a..5cdfe0adc6dd 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -41,12 +41,12 @@ from vllm.config import WeightTransferConfig from vllm.distributed.weight_transfer.base import ( WeightTransferInitRequest, - WeightUpdateRequest, + WeightTransferUpdateRequest, ) from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLInitInfo, - NCCLUpdateInfo, NCCLWeightTransferEngine, + NCCLWeightTransferInitInfo, + NCCLWeightTransferUpdateInfo, ) from vllm.utils.network_utils import get_ip, get_open_port from vllm.v1.executor import Executor @@ -156,7 +156,7 @@ def broadcast_weights(self, packed: bool = True): # Launch the vLLM inference engine. The `enforce_eager` flag reduces # start-up latency. -# Note: Weight transfer APIs (init_weight_transfer, update_weights, +# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights, # finalize_weight_update) are now native to vLLM workers. 
llm = ray.remote( num_cpus=0, @@ -197,10 +197,10 @@ def broadcast_weights(self, packed: bool = True): master_address, master_port = ray.get(train_model.get_master_address_and_port.remote()) world_size = 3 # 1 trainer + 2 inference workers (tensor_parallel_size=2) -inference_handle = llm.init_weight_transfer.remote( +inference_handle = llm.init_weight_transfer_engine.remote( WeightTransferInitRequest( init_info=asdict( - NCCLInitInfo( + NCCLWeightTransferInitInfo( master_address=master_address, master_port=master_port, rank_offset=1, @@ -232,9 +232,9 @@ def broadcast_weights(self, packed: bool = True): # Issue update_weights call with NCCL-specific update info # packed=True enables efficient batched tensor broadcasting inference_handle = llm.update_weights.remote( - WeightUpdateRequest( + WeightTransferUpdateRequest( update_info=asdict( - NCCLUpdateInfo( + NCCLWeightTransferUpdateInfo( names=names, dtype_names=dtype_names, shapes=shapes, diff --git a/examples/online_serving/rlhf_http.py b/examples/online_serving/rlhf_http.py index 40496ed6e99a..461a60a1fb8e 100644 --- a/examples/online_serving/rlhf_http.py +++ b/examples/online_serving/rlhf_http.py @@ -61,7 +61,7 @@ def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list return results -def init_weight_transfer( +def init_weight_transfer_engine( base_url: str, master_address: str, master_port: int, @@ -69,7 +69,7 @@ def init_weight_transfer( world_size: int, ) -> None: """Initialize weight transfer via HTTP endpoint.""" - url = f"{base_url}/init_weight_transfer" + url = f"{base_url}/init_weight_transfer_engine" payload = { "init_info": dict( master_address=master_address, @@ -181,7 +181,7 @@ def main(): import threading init_thread = threading.Thread( - target=init_weight_transfer, + target=init_weight_transfer_engine, args=(BASE_URL, master_address, master_port, rank_offset, world_size), ) init_thread.start() @@ -195,7 +195,7 @@ def main(): ), ) - # Wait for init_weight_transfer to complete 
+ # Wait for init_weight_transfer_engine to complete init_thread.join() # Pause generation before weight sync diff --git a/tests/distributed/test_packed_tensor.py b/tests/distributed/test_packed_tensor.py index d574e404f5e4..134629e2b790 100644 --- a/tests/distributed/test_packed_tensor.py +++ b/tests/distributed/test_packed_tensor.py @@ -9,8 +9,7 @@ import pytest import torch -from vllm import envs -from vllm.distributed.weight_transfer.nccl_engine import NCCLUpdateInfo +from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferUpdateInfo from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_consumer, packed_broadcast_producer, @@ -65,15 +64,15 @@ def create_state_dict_info( return {name: (tuple(tensor.shape), tensor.dtype) for name, tensor in params} -# --- Unit Tests: NCCLUpdateInfo packed field --- +# --- Unit Tests: NCCLWeightTransferUpdateInfo packed field --- -class TestNCCLUpdateInfoPacked: - """Test NCCLUpdateInfo dataclass packed field.""" +class TestNCCLWeightTransferUpdateInfoPacked: + """Test NCCLWeightTransferUpdateInfo dataclass packed field.""" def test_packed_default_false(self): """Test that packed defaults to False.""" - info = NCCLUpdateInfo( + info = NCCLWeightTransferUpdateInfo( names=["layer.weight"], dtype_names=["float32"], shapes=[[10, 10]], @@ -82,7 +81,7 @@ def test_packed_default_false(self): def test_packed_can_be_set_true(self): """Test that packed can be set to True.""" - info = NCCLUpdateInfo( + info = NCCLWeightTransferUpdateInfo( names=["layer.weight"], dtype_names=["float32"], shapes=[[10, 10]], @@ -98,7 +97,7 @@ def test_packed_can_be_set_true(self): class TestPackedBroadcastProducer: """Test packed_broadcast_producer function.""" - def test_producer_broadcasts_tensors(self, monkeypatch): + def test_producer_broadcasts_tensors(self): """Test that producer broadcasts all tensors.""" params = create_mock_model_params() params_cuda = [(name, tensor.cuda()) for name, tensor in params] @@ 
-106,19 +105,19 @@ def test_producer_broadcasts_tensors(self, monkeypatch): mock_group = MockCommunicationGroup() # Use a small target size to force multiple batches - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 500) packed_broadcast_producer( iterator=iter(params_cuda), group=mock_group, src=0, post_iter_func=lambda x: x[1], + buffer_size_bytes=500, ) # Should have broadcasted some tensors assert mock_group.broadcast_count > 0 assert len(mock_group.broadcasted_tensors) > 0 - def test_producer_single_large_tensor(self, monkeypatch): + def test_producer_single_large_tensor(self): """Test with a single tensor larger than target size.""" # Create a large tensor large_tensor = torch.randn(1000, 1000, dtype=torch.float32).cuda() @@ -127,12 +126,12 @@ def test_producer_single_large_tensor(self, monkeypatch): mock_group = MockCommunicationGroup() # Small target size to force the tensor to exceed it - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 100) packed_broadcast_producer( iterator=iter(params), group=mock_group, src=0, post_iter_func=lambda x: x[1], + buffer_size_bytes=100, ) # Should still broadcast the tensor (at least 1 broadcast) @@ -144,7 +143,7 @@ def test_producer_single_large_tensor(self, monkeypatch): actual_size = sum(t.numel() for t in mock_group.broadcasted_tensors) assert actual_size == expected_size - def test_producer_multiple_batches(self, monkeypatch): + def test_producer_multiple_batches(self): """Test that tensors are properly batched when exceeding target size.""" # Create many small tensors params = [ @@ -155,12 +154,12 @@ def test_producer_multiple_batches(self, monkeypatch): mock_group = MockCommunicationGroup() # Small target size to force multiple batches - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 2000) packed_broadcast_producer( iterator=iter(params), group=mock_group, src=0, post_iter_func=lambda x: x[1], + buffer_size_bytes=2000, ) # Should have multiple broadcasts @@ 
-171,16 +170,16 @@ def test_producer_multiple_batches(self, monkeypatch): actual_total = sum(t.numel() for t in mock_group.broadcasted_tensors) assert actual_total == expected_total - def test_producer_empty_iterator(self, monkeypatch): + def test_producer_empty_iterator(self): """Test producer handles empty iterator gracefully.""" mock_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 1000) packed_broadcast_producer( iterator=iter([]), group=mock_group, src=0, post_iter_func=lambda x: x[1], + buffer_size_bytes=1000, ) # No broadcasts for empty iterator @@ -194,20 +193,22 @@ def test_producer_empty_iterator(self, monkeypatch): class TestPackedBroadcastConsumer: """Test packed_broadcast_consumer function.""" - def test_consumer_receives_tensors(self, monkeypatch): + def test_consumer_receives_tensors(self): """Test that consumer receives and unpacks tensors.""" params = create_mock_model_params() params_cuda = [(name, tensor.cuda()) for name, tensor in params] + buffer_size = 2000 + # First, run producer to get the broadcasted tensors producer_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 2000) packed_broadcast_producer( iterator=iter(params_cuda), group=producer_group, src=0, post_iter_func=lambda x: x[1], + buffer_size_bytes=buffer_size, ) # Now run consumer with the broadcasted tensors @@ -228,6 +229,7 @@ def post_unpack_func(tensor_list): group=consumer_group, src=0, post_unpack_func=post_unpack_func, + buffer_size_bytes=buffer_size, ) # Verify all parameters were unpacked @@ -250,19 +252,20 @@ class TestPackedBroadcastRoundtrip: """Test producer-consumer roundtrip behavior.""" @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) - def test_roundtrip_different_dtypes(self, dtype, monkeypatch): + def test_roundtrip_different_dtypes(self, dtype): """Test roundtrip with different data types.""" params = 
create_mock_model_params(num_layers=2, dtype=dtype) params_cuda = [(name, tensor.cuda()) for name, tensor in params] + buffer_size = 1000 producer_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 1000) packed_broadcast_producer( iterator=iter(params_cuda), group=producer_group, src=0, post_iter_func=lambda x: x[1], + buffer_size_bytes=buffer_size, ) consumer_group = MockConsumerCommunicationGroup( @@ -281,6 +284,7 @@ def post_unpack_func(tensor_list): group=consumer_group, src=0, post_unpack_func=post_unpack_func, + buffer_size_bytes=buffer_size, ) # Verify roundtrip preserves data @@ -290,7 +294,7 @@ def post_unpack_func(tensor_list): assert unpacked.dtype == dtype assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6) - def test_roundtrip_mixed_dtypes(self, monkeypatch): + def test_roundtrip_mixed_dtypes(self): """Test roundtrip with mixed data types.""" # Create params with mixed dtypes params = [ @@ -299,14 +303,15 @@ def test_roundtrip_mixed_dtypes(self, monkeypatch): ("layer2.weight", torch.randn(20, 30, dtype=torch.bfloat16).cuda()), ] + buffer_size = 500 producer_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 500) packed_broadcast_producer( iterator=iter(params), group=producer_group, src=0, post_iter_func=lambda x: x[1], + buffer_size_bytes=buffer_size, ) consumer_group = MockConsumerCommunicationGroup( @@ -325,6 +330,7 @@ def post_unpack_func(tensor_list): group=consumer_group, src=0, post_unpack_func=post_unpack_func, + buffer_size_bytes=buffer_size, ) # Verify all params roundtrip correctly with correct dtypes @@ -336,19 +342,19 @@ def post_unpack_func(tensor_list): assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6) @pytest.mark.parametrize("target_size", [100, 1000, 10000, 100000]) - def test_roundtrip_different_batch_sizes(self, target_size, monkeypatch): + def test_roundtrip_different_batch_sizes(self, 
target_size): """Test roundtrip with different target batch sizes.""" params = create_mock_model_params(num_layers=5) params_cuda = [(name, tensor.cuda()) for name, tensor in params] producer_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", target_size) packed_broadcast_producer( iterator=iter(params_cuda), group=producer_group, src=0, post_iter_func=lambda x: x[1], + buffer_size_bytes=target_size, ) consumer_group = MockConsumerCommunicationGroup( @@ -367,6 +373,7 @@ def post_unpack_func(tensor_list): group=consumer_group, src=0, post_unpack_func=post_unpack_func, + buffer_size_bytes=target_size, ) # Verify all params roundtrip correctly @@ -377,7 +384,7 @@ def post_unpack_func(tensor_list): unpacked_tensors[name], original_tensor, rtol=1e-5, atol=1e-7 ) - def test_roundtrip_non_contiguous_tensors(self, monkeypatch): + def test_roundtrip_non_contiguous_tensors(self): """Test roundtrip with non-contiguous tensors from the trainer.""" # Create non-contiguous tensors (simulating trainer outputs) # Transposed tensors are non-contiguous @@ -397,14 +404,15 @@ def test_roundtrip_non_contiguous_tensors(self, monkeypatch): for name, tensor in params: assert not tensor.is_contiguous(), f"{name} should be non-contiguous" + buffer_size = 500 producer_group = MockCommunicationGroup() - monkeypatch.setattr(envs, "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", 500) packed_broadcast_producer( iterator=iter(params), group=producer_group, src=0, post_iter_func=lambda x: x[1], + buffer_size_bytes=buffer_size, ) consumer_group = MockConsumerCommunicationGroup( @@ -423,6 +431,7 @@ def post_unpack_func(tensor_list): group=consumer_group, src=0, post_unpack_func=post_unpack_func, + buffer_size_bytes=buffer_size, ) # Verify all non-contiguous params roundtrip correctly diff --git a/tests/distributed/test_weight_transfer.py b/tests/distributed/test_weight_transfer.py index 8c16aebe0c10..4c348dd799b5 100644 --- 
a/tests/distributed/test_weight_transfer.py +++ b/tests/distributed/test_weight_transfer.py @@ -16,9 +16,9 @@ from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer import WeightTransferEngineFactory from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLInitInfo, - NCCLUpdateInfo, NCCLWeightTransferEngine, + NCCLWeightTransferInitInfo, + NCCLWeightTransferUpdateInfo, ) from vllm.utils.network_utils import get_open_port @@ -36,15 +36,15 @@ def create_mock_parallel_config( return config -# --- Unit Tests: NCCLUpdateInfo Validation --- +# --- Unit Tests: NCCLWeightTransferUpdateInfo Validation --- -class TestNCCLUpdateInfoValidation: - """Test NCCLUpdateInfo dataclass validation.""" +class TestNCCLWeightTransferUpdateInfoValidation: + """Test NCCLWeightTransferUpdateInfo dataclass validation.""" def test_valid_update_info(self): - """Test creating valid NCCLUpdateInfo.""" - info = NCCLUpdateInfo( + """Test creating valid NCCLWeightTransferUpdateInfo.""" + info = NCCLWeightTransferUpdateInfo( names=["layer.weight", "layer.bias"], dtype_names=["float32", "float32"], shapes=[[10, 10], [10]], @@ -56,7 +56,7 @@ def test_valid_update_info(self): def test_mismatched_dtype_names_raises(self): """Test that mismatched dtype_names length raises ValueError.""" with pytest.raises(ValueError, match="dtype_names"): - NCCLUpdateInfo( + NCCLWeightTransferUpdateInfo( names=["layer.weight", "layer.bias"], dtype_names=["float32"], # Only one dtype shapes=[[10, 10], [10]], @@ -65,7 +65,7 @@ def test_mismatched_dtype_names_raises(self): def test_mismatched_shapes_raises(self): """Test that mismatched shapes length raises ValueError.""" with pytest.raises(ValueError, match="shapes"): - NCCLUpdateInfo( + NCCLWeightTransferUpdateInfo( names=["layer.weight", "layer.bias"], dtype_names=["float32", "float32"], shapes=[[10, 10]], # Only one shape @@ -73,7 +73,7 @@ def test_mismatched_shapes_raises(self): def test_empty_lists_valid(self): 
"""Test that empty lists are valid.""" - info = NCCLUpdateInfo( + info = NCCLWeightTransferUpdateInfo( names=[], dtype_names=[], shapes=[], @@ -102,7 +102,7 @@ def test_parse_init_info_valid(self): } ) - assert isinstance(init_info, NCCLInitInfo) + assert isinstance(init_info, NCCLWeightTransferInitInfo) assert init_info.master_address == "127.0.0.1" assert init_info.master_port == 12345 assert init_info.rank_offset == 1 @@ -136,7 +136,7 @@ def test_parse_update_info_valid(self): } ) - assert isinstance(update_info, NCCLUpdateInfo) + assert isinstance(update_info, NCCLWeightTransferUpdateInfo) assert update_info.names == ["w1", "w2"] assert update_info.dtype_names == ["float32", "bfloat16"] assert update_info.shapes == [[100, 100], [50]] @@ -174,7 +174,7 @@ def test_register_duplicate_raises(self): def test_nccl_receive_weights_without_init_raises(): - """Test that receive_weights raises if init_transfer wasn't called.""" + """Test that receive_weights raises if init_transfer_engine wasn't called.""" if torch.cuda.device_count() < 1: pytest.skip("Need at least 1 GPU for this test") @@ -182,7 +182,7 @@ def test_nccl_receive_weights_without_init_raises(): parallel_config = create_mock_parallel_config() engine = NCCLWeightTransferEngine(config, parallel_config) - update_info = NCCLUpdateInfo( + update_info = NCCLWeightTransferUpdateInfo( names=["w"], dtype_names=["float32"], shapes=[[10]], @@ -244,9 +244,9 @@ def inference_receive_tensor( from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer.nccl_engine import ( - NCCLInitInfo, - NCCLUpdateInfo, NCCLWeightTransferEngine, + NCCLWeightTransferInitInfo, + NCCLWeightTransferUpdateInfo, ) # Create engine with mock parallel config @@ -259,13 +259,13 @@ def inference_receive_tensor( engine = NCCLWeightTransferEngine(config, parallel_config) # Initialize the engine (joins as rank 1) - init_info = NCCLInitInfo( + init_info = 
NCCLWeightTransferInitInfo( master_address=master_address, master_port=master_port, rank_offset=1, # Trainer is rank 0, we become rank 1 world_size=world_size, ) - engine.init_transfer(init_info) + engine.init_transfer_engine(init_info) # Receive weights with a no-op load_weights that captures the tensor received_tensors = [] @@ -275,7 +275,7 @@ def noop_load_weights(weights: list[tuple[str, torch.Tensor]]): # Clone tensor to keep it after engine cleans up received_tensors.append((name, tensor.clone())) - update_info = NCCLUpdateInfo( + update_info = NCCLWeightTransferUpdateInfo( names=["test.weight"], dtype_names=[tensor_dtype], shapes=[tensor_shape], diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 9e7758329904..89e58f833fe8 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -142,7 +142,7 @@ def test_openapi_stateless(case: schemathesis.Case): # Skip weight transfer endpoints as they require special setup # (weight_transfer_config) and are meant to be stateful. 
if case.operation.path in ( - "/init_weight_transfer", + "/init_weight_transfer_engine", "/update_weights", "/finalize_weight_update", ): diff --git a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py index 1c0ac3bbe258..582ce5bd87ff 100644 --- a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py +++ b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py @@ -18,11 +18,11 @@ from vllm import LLM from vllm.config import WeightTransferConfig from vllm.distributed.weight_transfer.base import ( - BackendInitInfo, - BackendUpdateInfo, WeightTransferEngine, + WeightTransferInitInfo, WeightTransferInitRequest, - WeightUpdateRequest, + WeightTransferUpdateInfo, + WeightTransferUpdateRequest, ) from ...utils import create_new_process_for_each_test @@ -35,14 +35,14 @@ @dataclass -class MockInitInfo(BackendInitInfo): +class MockInitInfo(WeightTransferInitInfo): """Mock initialization info.""" test_param: str = "test" @dataclass -class MockUpdateInfo(BackendUpdateInfo): +class MockUpdateInfo(WeightTransferUpdateInfo): """Mock update info.""" names: list[str] | None = None @@ -57,7 +57,7 @@ class MockWeightTransferEngine(WeightTransferEngine[MockInitInfo, MockUpdateInfo update_info_cls = MockUpdateInfo # Class-level tracking for verification across processes - init_transfer_called: bool = False + init_transfer_engine_called: bool = False receive_weights_called: bool = False shutdown_called: bool = False last_init_info: MockInitInfo | None = None @@ -66,14 +66,14 @@ class MockWeightTransferEngine(WeightTransferEngine[MockInitInfo, MockUpdateInfo def __init__(self, config, parallel_config): super().__init__(config, parallel_config) # Reset tracking on init - MockWeightTransferEngine.init_transfer_called = False + MockWeightTransferEngine.init_transfer_engine_called = False MockWeightTransferEngine.receive_weights_called = False MockWeightTransferEngine.shutdown_called = False 
MockWeightTransferEngine.last_init_info = None MockWeightTransferEngine.last_update_info = None - def init_transfer(self, init_info: MockInitInfo) -> None: - MockWeightTransferEngine.init_transfer_called = True + def init_transfer_engine(self, init_info: MockInitInfo) -> None: + MockWeightTransferEngine.init_transfer_engine_called = True MockWeightTransferEngine.last_init_info = init_info def receive_weights( @@ -118,8 +118,9 @@ def test_get_world_size_tp1(): @create_new_process_for_each_test() -def test_init_weight_transfer_calls_engine(): - """Test that init_weight_transfer calls the engine's init_transfer method.""" +def test_init_weight_transfer_engine_calls_engine(): + """Test that init_weight_transfer_engine calls the engine's + init_transfer_engine method.""" if torch.cuda.device_count() < 1: pytest.skip("Need at least 1 GPU for this test") @@ -145,22 +146,22 @@ def check_engine_exists(self): results = llm.collective_rpc(check_engine_exists) assert all(results), "Weight transfer engine should be initialized" - # Call init_weight_transfer - llm.init_weight_transfer( + # Call init_weight_transfer_engine + llm.init_weight_transfer_engine( WeightTransferInitRequest(init_info={"test_param": "hello"}) ) - # Verify init_transfer was called on the engine + # Verify init_transfer_engine was called on the engine def check_init_called(self): engine = self.weight_transfer_engine return ( - engine.init_transfer_called, + engine.init_transfer_engine_called, engine.last_init_info.test_param if engine.last_init_info else None, ) results = llm.collective_rpc(check_init_called) for called, param in results: - assert called, "init_transfer should have been called" + assert called, "init_transfer_engine should have been called" assert param == "hello", f"Expected 'hello', got {param}" @@ -186,7 +187,7 @@ def test_update_weights_calls_engine(): ) # First init the weight transfer - llm.init_weight_transfer( + llm.init_weight_transfer_engine( 
WeightTransferInitRequest(init_info={"test_param": "init"}) ) @@ -196,7 +197,7 @@ def test_update_weights_calls_engine(): test_shapes = [[10, 10], [10]] llm.update_weights( - WeightUpdateRequest( + WeightTransferUpdateRequest( update_info={ "names": test_names, "dtype_names": test_dtypes, @@ -266,13 +267,13 @@ def test_full_weight_transfer_flow(): ) # Step 1: Initialize - llm.init_weight_transfer( + llm.init_weight_transfer_engine( WeightTransferInitRequest(init_info={"test_param": "flow_test"}) ) # Step 2: Update weights llm.update_weights( - WeightUpdateRequest( + WeightTransferUpdateRequest( update_info={ "names": ["test.weight"], "dtype_names": ["bfloat16"], @@ -288,7 +289,7 @@ def test_full_weight_transfer_flow(): def check_flow(self): engine = self.weight_transfer_engine return { - "init_called": engine.init_transfer_called, + "init_called": engine.init_transfer_engine_called, "update_called": engine.receive_weights_called, "init_param": ( engine.last_init_info.test_param if engine.last_init_info else None @@ -300,7 +301,7 @@ def check_flow(self): results = llm.collective_rpc(check_flow) for result in results: - assert result["init_called"], "init_transfer should be called" + assert result["init_called"], "init_transfer_engine should be called" assert result["update_called"], "receive_weights should be called" assert result["init_param"] == "flow_test" assert result["update_names"] == ["test.weight"] diff --git a/vllm/distributed/weight_transfer/base.py b/vllm/distributed/weight_transfer/base.py index 97ae60e4ab46..69c0cbee9ce2 100644 --- a/vllm/distributed/weight_transfer/base.py +++ b/vllm/distributed/weight_transfer/base.py @@ -12,20 +12,20 @@ from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig -TInitInfo = TypeVar("TInitInfo", bound="BackendInitInfo") -TUpdateInfo = TypeVar("TUpdateInfo", bound="BackendUpdateInfo") +TInitInfo = TypeVar("TInitInfo", bound="WeightTransferInitInfo") +TUpdateInfo = 
TypeVar("TUpdateInfo", bound="WeightTransferUpdateInfo") # Base protocols for backend-specific dataclasses @dataclass -class BackendInitInfo(ABC): # noqa: B024 +class WeightTransferInitInfo(ABC): # noqa: B024 """Base class for backend-specific initialization info.""" pass @dataclass -class BackendUpdateInfo(ABC): # noqa: B024 +class WeightTransferUpdateInfo(ABC): # noqa: B024 """Base class for backend-specific weight update info.""" pass @@ -40,7 +40,7 @@ class WeightTransferInitRequest: @dataclass -class WeightUpdateRequest: +class WeightTransferUpdateRequest: """API-level weight update request.""" update_info: dict[str, Any] = field(default_factory=dict) @@ -118,7 +118,7 @@ def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo: ) from e @abstractmethod - def init_transfer(self, init_info: TInitInfo) -> None: + def init_transfer_engine(self, init_info: TInitInfo) -> None: """ Initialize the weight transfer mechanism. This is called once at the beginning of training. diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index 8140c9f66a43..5c90198bf616 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -14,14 +14,19 @@ from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer.base import ( - BackendInitInfo, - BackendUpdateInfo, WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferUpdateInfo, +) +from vllm.distributed.weight_transfer.packed_tensor import ( + DEFAULT_PACKED_BUFFER_SIZE_BYTES, + DEFAULT_PACKED_NUM_BUFFERS, + packed_broadcast_consumer, ) @dataclass -class NCCLInitInfo(BackendInitInfo): +class NCCLWeightTransferInitInfo(WeightTransferInitInfo): """Initialization info for NCCL weight transfer backend.""" master_address: str @@ -31,7 +36,7 @@ class NCCLInitInfo(BackendInitInfo): @dataclass -class 
NCCLUpdateInfo(BackendUpdateInfo): +class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo): """Update info for NCCL weight transfer backend.""" names: list[str] @@ -41,6 +46,12 @@ class NCCLUpdateInfo(BackendUpdateInfo): """Whether to use packed tensor broadcasting for efficiency. When True, multiple tensors are batched together before broadcasting to reduce NCCL communication overhead.""" + packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES + """Size in bytes for each packed tensor buffer. Default is 1GB. + Both producer and consumer must use the same value.""" + packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS + """Number of buffers for double/triple buffering during packed transfer. + Both producer and consumer must use the same value.""" def __post_init__(self): """Validate that all lists have the same length.""" @@ -57,7 +68,9 @@ def __post_init__(self): ) -class NCCLWeightTransferEngine(WeightTransferEngine[NCCLInitInfo, NCCLUpdateInfo]): +class NCCLWeightTransferEngine( + WeightTransferEngine[NCCLWeightTransferInitInfo, NCCLWeightTransferUpdateInfo] +): """ Weight transfer engine using NCCL for communication between trainer and workers. @@ -66,8 +79,8 @@ class NCCLWeightTransferEngine(WeightTransferEngine[NCCLInitInfo, NCCLUpdateInfo """ # Define backend-specific dataclass types - init_info_cls = NCCLInitInfo - update_info_cls = NCCLUpdateInfo + init_info_cls = NCCLWeightTransferInitInfo + update_info_cls = NCCLWeightTransferUpdateInfo def __init__( self, config: WeightTransferConfig, parallel_config: ParallelConfig @@ -82,7 +95,7 @@ def __init__( super().__init__(config, parallel_config) self.model_update_group: PyNcclCommunicator | None = None - def init_transfer(self, init_info: NCCLInitInfo) -> None: + def init_transfer_engine(self, init_info: NCCLWeightTransferInitInfo) -> None: """ Initialize NCCL process group with the trainer. 
@@ -95,10 +108,10 @@ def init_transfer(self, init_info: NCCLInitInfo) -> None: # Must account for data parallel to get unique ranks across all workers dp_rank = self.parallel_config.data_parallel_rank world_size_per_dp = self.parallel_config.world_size # TP * PP - tp_rank = self.parallel_config.rank + rank_within_dp = self.parallel_config.rank # Unique rank across all DP groups - worker_rank = dp_rank * world_size_per_dp + tp_rank + worker_rank = dp_rank * world_size_per_dp + rank_within_dp rank = worker_rank + init_info.rank_offset # Create stateless process group self.model_update_group = ( @@ -113,7 +126,7 @@ def init_transfer(self, init_info: NCCLInitInfo) -> None: def receive_weights( self, - update_info: NCCLUpdateInfo, + update_info: NCCLWeightTransferUpdateInfo, load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], ) -> None: """ @@ -131,15 +144,11 @@ def receive_weights( """ if self.model_update_group is None: raise RuntimeError( - "NCCL weight transfer not initialized. Call init_transfer() first." + "NCCL weight transfer not initialized. " + "Call init_transfer_engine() first." ) if update_info.packed: - # Use packed tensor broadcasting for efficiency - from vllm.distributed.weight_transfer.packed_tensor import ( - packed_broadcast_consumer, - ) - # Build iterator of (name, (shape, dtype)) from update_info def state_dict_info_iterator(): for name, dtype_name, shape in zip( @@ -153,6 +162,8 @@ def state_dict_info_iterator(): group=self.model_update_group, src=0, post_unpack_func=load_weights, + buffer_size_bytes=update_info.packed_buffer_size_bytes, + num_buffers=update_info.packed_num_buffers, ) else: # Use simple one-by-one broadcasting @@ -181,6 +192,8 @@ def trainer_send_weights( | None = None, packed: bool = False, stream: torch.cuda.Stream | None = None, + packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES, + packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS, ) -> None: """Broadcast weights from trainer to vLLM workers. 
@@ -195,6 +208,10 @@ def trainer_send_weights( broadcasting to reduce NCCL communication overhead. stream: CUDA stream to use for broadcasting if packed is False. If packed is True, new streams will be created for each buffer. + packed_buffer_size_bytes: Size in bytes for each packed tensor buffer. + Must match the value used in NCCLWeightTransferUpdateInfo. + packed_num_buffers: Number of buffers for double/triple buffering. + Must match the value used in NCCLWeightTransferUpdateInfo. Example: >>> from vllm.distributed.weight_transfer.nccl_engine import ( @@ -220,6 +237,8 @@ def trainer_send_weights( group=group, src=src, post_iter_func=post_iter_func, + buffer_size_bytes=packed_buffer_size_bytes, + num_buffers=packed_num_buffers, ) else: # Use simple one-by-one broadcasting @@ -231,7 +250,7 @@ def trainer_send_weights( @staticmethod def trainer_init( - init_info: NCCLInitInfo | dict, + init_info: NCCLWeightTransferInitInfo | dict, ) -> "PyNcclCommunicator": """ Initialize NCCL process group for trainer-side weight transfer. @@ -240,7 +259,7 @@ def trainer_init( CUDA device (torch.cuda.current_device()). Args: - init_info: Either an NCCLInitInfo object or a dict with keys: + init_info: Either an NCCLWeightTransferInitInfo object or a dict with keys: - master_address: str - master_port: int - world_size: int @@ -265,7 +284,7 @@ def trainer_init( master_port = init_info["master_port"] world_size = init_info["world_size"] else: - # NCCLInitInfo object + # NCCLWeightTransferInitInfo object master_address = init_info.master_address master_port = init_info.master_port world_size = init_info.world_size diff --git a/vllm/distributed/weight_transfer/packed_tensor.py b/vllm/distributed/weight_transfer/packed_tensor.py index 3c5a06f600b6..1c96d72edac7 100644 --- a/vllm/distributed/weight_transfer/packed_tensor.py +++ b/vllm/distributed/weight_transfer/packed_tensor.py @@ -8,7 +8,10 @@ import torch -from vllm import envs +# Default values for packed tensor configuration. 
+# These are imported by NCCLWeightTransferUpdateInfo and trainer_send_weights. +DEFAULT_PACKED_BUFFER_SIZE_BYTES = 1024 * 1024 * 1024 # 1GB +DEFAULT_PACKED_NUM_BUFFERS = 2 def packed_broadcast_producer( @@ -16,6 +19,8 @@ def packed_broadcast_producer( group: Any, src: int, post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor], + buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES, + num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS, ) -> None: """Broadcast tensors in a packed manner from trainer to workers. @@ -25,10 +30,13 @@ def packed_broadcast_producer( src: Source rank (0 in current implementation) post_iter_func: Function to apply to each (name, tensor) pair before packing, should return a tensor + buffer_size_bytes: Size in bytes for each packed tensor buffer. + Both producer and consumer must use the same value. + num_buffers: Number of buffers for double/triple buffering. + Both producer and consumer must use the same value. """ - target_packed_tensor_size = envs.VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES - num_buffers = envs.VLLM_PACKED_TENSOR_NUM_BUFFERS + target_packed_tensor_size = buffer_size_bytes streams = [torch.cuda.Stream() for _ in range(num_buffers)] buffer_idx = 0 @@ -83,6 +91,8 @@ def packed_broadcast_consumer( group: Any, src: int, post_unpack_func: Callable[[list[tuple[str, torch.Tensor]]], None], + buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES, + num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS, ) -> None: """Consume packed tensors and unpack them into a list of tensors. @@ -92,6 +102,10 @@ def packed_broadcast_consumer( src: Source rank (0 in current implementation) post_unpack_func: Function to apply to each list of (name, tensor) after unpacking + buffer_size_bytes: Size in bytes for each packed tensor buffer. + Both producer and consumer must use the same value. + num_buffers: Number of buffers for double/triple buffering. + Both producer and consumer must use the same value. 
""" @@ -125,8 +139,7 @@ def unpack_tensor( return unpacked_list - target_packed_tensor_size = envs.VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES - num_buffers = envs.VLLM_PACKED_TENSOR_NUM_BUFFERS + target_packed_tensor_size = buffer_size_bytes streams = [torch.cuda.Stream() for _ in range(num_buffers)] buffer_idx = 0 diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 8f1fec850076..bbd09502c6c0 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -8,7 +8,7 @@ from vllm.config import ModelConfig, VllmConfig from vllm.distributed.weight_transfer.base import ( WeightTransferInitRequest, - WeightUpdateRequest, + WeightTransferUpdateRequest, ) from vllm.inputs.data import PromptType from vllm.lora.request import LoRARequest @@ -196,13 +196,13 @@ async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: """Get supported tasks""" raise NotImplementedError - async def init_weight_transfer( + async def init_weight_transfer_engine( self, init_request: WeightTransferInitRequest ) -> None: """Initialize weight transfer for RL training.""" raise NotImplementedError - async def update_weights(self, request: WeightUpdateRequest) -> None: + async def update_weights(self, request: WeightTransferUpdateRequest) -> None: """Batched weight update for RL training.""" raise NotImplementedError diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index dbd1b5763be9..a65eb4935c07 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -35,7 +35,7 @@ ) from vllm.distributed.weight_transfer.base import ( WeightTransferInitRequest, - WeightUpdateRequest, + WeightTransferUpdateRequest, ) from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.chat_utils import ( @@ -1722,7 +1722,9 @@ def _run_engine( # its previous requests. 
return sorted(outputs, key=lambda x: int(x.request_id)) - def init_weight_transfer(self, request: WeightTransferInitRequest | dict) -> None: + def init_weight_transfer_engine( + self, request: WeightTransferInitRequest | dict + ) -> None: """ Initialize weight transfer for RL training. @@ -1734,10 +1736,10 @@ def init_weight_transfer(self, request: WeightTransferInitRequest | dict) -> Non ) self.llm_engine.collective_rpc( - "init_weight_transfer", kwargs={"init_info": init_info_dict} + "init_weight_transfer_engine", kwargs={"init_info": init_info_dict} ) - def update_weights(self, request: WeightUpdateRequest | dict) -> None: + def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None: """ Update the weights of the model. diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py index c44014d8874a..87b17fe8ff9c 100644 --- a/vllm/entrypoints/serve/rlhf/api_router.py +++ b/vllm/entrypoints/serve/rlhf/api_router.py @@ -10,7 +10,7 @@ import vllm.envs as envs from vllm.distributed.weight_transfer.base import ( WeightTransferInitRequest, - WeightUpdateRequest, + WeightTransferUpdateRequest, ) from vllm.engine.protocol import EngineClient from vllm.logger import init_logger @@ -103,8 +103,8 @@ async def is_paused(raw_request: Request) -> JSONResponse: return JSONResponse(content={"is_paused": paused}) -@router.post("/init_weight_transfer") -async def init_weight_transfer(raw_request: Request): +@router.post("/init_weight_transfer_engine") +async def init_weight_transfer_engine(raw_request: Request): try: body = await raw_request.json() except json.JSONDecodeError as e: @@ -115,7 +115,7 @@ async def init_weight_transfer(raw_request: Request): status_code=HTTPStatus.BAD_REQUEST.value, detail="Missing 'init_info' in request body", ) - await engine_client(raw_request).init_weight_transfer( + await engine_client(raw_request).init_weight_transfer_engine( WeightTransferInitRequest(init_info=init_info) ) return 
JSONResponse(content={"message": "Weight transfer initialized"}) @@ -134,7 +134,7 @@ async def update_weights(raw_request: Request): detail="Missing 'update_info' in request body", ) await engine_client(raw_request).update_weights( - request=WeightUpdateRequest(update_info=update_info) + request=WeightTransferUpdateRequest(update_info=update_info) ) return JSONResponse(content={"message": "Weights updated"}) diff --git a/vllm/envs.py b/vllm/envs.py index 24c00acd1189..ad220a979d44 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -251,8 +251,6 @@ VLLM_USE_V2_MODEL_RUNNER: bool = False VLLM_LOG_MODEL_INSPECTION: bool = False VLLM_DEBUG_MFU_METRICS: bool = False - VLLM_PACKED_TENSOR_NUM_BUFFERS: int = 2 - VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES: int = 1024 * 1024 * 1024 # 1GB VLLM_DISABLE_LOG_LOGO: bool = False VLLM_LORA_DISABLE_PDL: bool = False @@ -1623,14 +1621,6 @@ def _get_or_set_default() -> str: "VLLM_DEBUG_MFU_METRICS": lambda: bool( int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0")) ), - # Number of buffers for packed tensor weight transfer in NCCLWeightTransferEngine - "VLLM_PACKED_TENSOR_NUM_BUFFERS": lambda: int( - os.getenv("VLLM_PACKED_TENSOR_NUM_BUFFERS", "2") - ), - # Size in bytes for each packed tensor buffer (default 1GB) - "VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES": lambda: int( - os.getenv("VLLM_PACKED_TENSOR_BUFFER_SIZE_BYTES", str(1024 * 1024 * 1024)) - ), # Disable logging of vLLM logo at server startup time. 
"VLLM_DISABLE_LOG_LOGO": lambda: bool(int(os.getenv("VLLM_DISABLE_LOG_LOGO", "0"))), # Disable PDL for LoRA, as enabling PDL with LoRA on SM100 causes diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a4c729e5897e..4865cd482e72 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -17,7 +17,7 @@ from vllm.config import VllmConfig from vllm.distributed.weight_transfer.base import ( WeightTransferInitRequest, - WeightUpdateRequest, + WeightTransferUpdateRequest, ) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient @@ -1026,7 +1026,9 @@ def errored(self) -> bool: def dead_error(self) -> BaseException: return EngineDeadError() - async def init_weight_transfer(self, request: WeightTransferInitRequest) -> None: + async def init_weight_transfer_engine( + self, request: WeightTransferInitRequest + ) -> None: """ Initialize weight transfer for RL training. @@ -1043,10 +1045,10 @@ async def init_weight_transfer(self, request: WeightTransferInitRequest) -> None raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}") await self.collective_rpc( - "init_weight_transfer", kwargs={"init_info": init_info_dict} + "init_weight_transfer_engine", kwargs={"init_info": init_info_dict} ) - async def update_weights(self, request: WeightUpdateRequest) -> None: + async def update_weights(self, request: WeightTransferUpdateRequest) -> None: """ Batched weight update for RL training. 
@@ -1054,10 +1056,12 @@ async def update_weights(self, request: WeightUpdateRequest) -> None: request: Weight update request with backend-specific update info """ - if isinstance(request, WeightUpdateRequest): + if isinstance(request, WeightTransferUpdateRequest): update_info_dict = request.update_info else: - raise TypeError(f"Expected WeightUpdateRequest, got {type(request)}") + raise TypeError( + f"Expected WeightTransferUpdateRequest, got {type(request)}" + ) await self.collective_rpc( "update_weights", kwargs={"update_info": update_info_dict} diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 5078b74dd5a0..28039b13181d 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -936,7 +936,7 @@ def save_tensorized_model( tensorizer_config=tensorizer_config, ) - def init_weight_transfer(self, init_info: dict) -> None: + def init_weight_transfer_engine(self, init_info: dict) -> None: """ Initialize weight transfer mechanism. For NCCL backend, this creates a process group with the trainer. 
@@ -951,7 +951,7 @@ def init_weight_transfer(self, init_info: dict) -> None: ) # Parse dict into backend-specific typed dataclass typed_init_info = self.weight_transfer_engine.parse_init_info(init_info) - self.weight_transfer_engine.init_transfer(typed_init_info) + self.weight_transfer_engine.init_transfer_engine(typed_init_info) def update_weights(self, update_info: dict) -> None: """ From a003e63b05edad43b761d4302145fd529b919745 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 29 Jan 2026 13:10:43 -0800 Subject: [PATCH 31/36] x Signed-off-by: ahao-anyscale --- examples/online_serving/rlhf_http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/online_serving/rlhf_http.py b/examples/online_serving/rlhf_http.py index 461a60a1fb8e..b67ac336f4f2 100644 --- a/examples/online_serving/rlhf_http.py +++ b/examples/online_serving/rlhf_http.py @@ -13,7 +13,7 @@ Prerequisites: Start a vLLM server with weight transfer enabled: - $ vllm serve facebook/opt-125m \ + $ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \ --enforce-eager \ --weight-transfer-config '{"backend": "nccl"}' \ --load-format dummy From 1aced0b3eacee08c21f911f377a19dc16220c6ec Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Fri, 30 Jan 2026 12:24:05 -0800 Subject: [PATCH 32/36] integrated layerwise reloading Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf.py | 8 ++-- vllm/distributed/weight_transfer/base.py | 8 +++- vllm/v1/worker/gpu_worker.py | 37 ++++++++++++++++--- 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py index 7fe5c7a66d06..90b253133d8b 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -29,7 +29,6 @@ import os import ray -import torch from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import 
PlacementGroupSchedulingStrategy from transformers import AutoModelForCausalLM @@ -41,7 +40,8 @@ ) from vllm.utils.network_utils import get_ip, get_open_port -MODEL_NAME = "facebook/opt-125m" +# MODEL_NAME = "facebook/opt-125m" +MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128" class MyLLM(LLM): @@ -58,7 +58,7 @@ class TrainModel: def __init__(self, model_name: str): self.model = AutoModelForCausalLM.from_pretrained( - model_name, dtype=torch.bfloat16 + model_name, ).to("cuda:0") self.port = get_open_port() self.master_address = get_ip() @@ -196,7 +196,7 @@ def broadcast_weights(self, packed: bool = True): ray.get([train_handle, inference_handle]) # Finalize the weight update -ray.get(llm.finalize_weight_update.remote()) +# ray.get(llm.finalize_weight_update.remote()) # Generate text with the updated model. The output is expected to be normal # because the weights are updated. diff --git a/vllm/distributed/weight_transfer/base.py b/vllm/distributed/weight_transfer/base.py index 69c0cbee9ce2..b87f190fcf7a 100644 --- a/vllm/distributed/weight_transfer/base.py +++ b/vllm/distributed/weight_transfer/base.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from dataclasses import dataclass, field +from dataclasses import KW_ONLY, dataclass, field from typing import Any, Generic, TypeVar import torch @@ -28,7 +28,11 @@ class WeightTransferInitInfo(ABC): # noqa: B024 class WeightTransferUpdateInfo(ABC): # noqa: B024 """Base class for backend-specific weight update info.""" - pass + _: KW_ONLY + is_checkpoint_format: bool = True + """Set to True if weights are in checkpoint/original model format and need + layerwise processing. 
Set to False if weights have already been processed + into kernel format (repacking, renaming, etc.).""" # API-level request classes (accept dicts for backend-agnostic serialization) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 295b7ecc2371..d98e52202060 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -972,11 +972,38 @@ def update_weights(self, update_info: dict) -> None: # Parse dict into backend-specific typed dataclass typed_update_info = self.weight_transfer_engine.parse_update_info(update_info) - # Receive and load weights incrementally to avoid OOM - self.weight_transfer_engine.receive_weights( - typed_update_info, - load_weights=self.model_runner.model.load_weights, - ) + model = self.model_runner.model + + if typed_update_info.is_checkpoint_format: + from vllm.model_executor.model_loader.reload import ( + finalize_layerwise_reload, + initialize_layerwise_reload, + ) + + # Use layerwise reload pattern for checkpoint format weights + with torch.device(self.device): + initialize_layerwise_reload(model) + self.weight_transfer_engine.receive_weights( + typed_update_info, + load_weights=model.load_weights, + ) + finalize_layerwise_reload(model, self.model_config) + else: + # Weights are already in kernel format, copy directly + def load_weights_direct( + weights: list[tuple[str, torch.Tensor]], + ) -> set[str]: + loaded = set() + for name, weight in weights: + param = model.get_parameter(name) + param.copy_(weight) + loaded.add(name) + return loaded + + self.weight_transfer_engine.receive_weights( + typed_update_info, + load_weights=load_weights_direct, + ) def finalize_weight_update(self) -> None: """ From cc4c67eed534e0b7370948987ca2e5f07696b3bb Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Fri, 30 Jan 2026 12:29:39 -0800 Subject: [PATCH 33/36] removed finalize weight update Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf.py | 7 ++--- 
.../new_weight_syncing/rlhf_async_new_apis.py | 7 ++--- examples/online_serving/rlhf_http.py | 10 ------- .../entrypoints/openai/test_openai_schema.py | 1 - .../test_weight_transfer_llm.py | 28 +------------------ vllm/engine/protocol.py | 5 ---- vllm/entrypoints/llm.py | 7 ----- vllm/entrypoints/serve/rlhf/api_router.py | 6 ---- vllm/v1/engine/async_llm.py | 6 ---- vllm/v1/worker/gpu_worker.py | 11 -------- 10 files changed, 5 insertions(+), 83 deletions(-) diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py index 90b253133d8b..bfc6b2df952a 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -117,8 +117,8 @@ def broadcast_weights(self, packed: bool = True): # Launch the vLLM inference engine. The `enforce_eager` flag reduces # start-up latency. -# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights, -# finalize_weight_update) are now native to vLLM workers. +# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights) +# are now native to vLLM workers. llm = ray.remote( num_cpus=0, num_gpus=0, @@ -195,9 +195,6 @@ def broadcast_weights(self, packed: bool = True): train_handle = train_model.broadcast_weights.remote(packed=True) ray.get([train_handle, inference_handle]) -# Finalize the weight update -# ray.get(llm.finalize_weight_update.remote()) - # Generate text with the updated model. The output is expected to be normal # because the weights are updated. 
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index 5cdfe0adc6dd..835c16a7f55c 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -156,8 +156,8 @@ def broadcast_weights(self, packed: bool = True): # Launch the vLLM inference engine. The `enforce_eager` flag reduces # start-up latency. -# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights, -# finalize_weight_update) are now native to vLLM workers. +# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights) +# are now native to vLLM workers. llm = ray.remote( num_cpus=0, num_gpus=0, @@ -248,9 +248,6 @@ def broadcast_weights(self, packed: bool = True): train_handle = train_model.broadcast_weights.remote(packed=True) ray.get([train_handle, inference_handle]) -# Finalize the weight update (processes weights for quantization/kernel format) -ray.get(llm.finalize_weight_update.remote()) - # Resume generation since weight sync is complete ray.get(llm.resume_generation.remote()) diff --git a/examples/online_serving/rlhf_http.py b/examples/online_serving/rlhf_http.py index b67ac336f4f2..721a038a6600 100644 --- a/examples/online_serving/rlhf_http.py +++ b/examples/online_serving/rlhf_http.py @@ -103,13 +103,6 @@ def update_weights( response.raise_for_status() -def finalize_weight_update(base_url: str) -> None: - """Finalize weight update via HTTP endpoint.""" - url = f"{base_url}/finalize_weight_update" - response = requests.post(url, timeout=60) - response.raise_for_status() - - def pause_generation(base_url: str) -> None: """Pause generation via HTTP endpoint.""" url = f"{base_url}/pause" @@ -230,9 +223,6 @@ def main(): # Wait for update_weights to complete update_thread.join() - # Finalize the weight 
update (processes weights for quantization/kernel format) - finalize_weight_update(BASE_URL) - # Resume generation after weight sync resume_generation(BASE_URL) diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 89e58f833fe8..1baab9934fdd 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -144,7 +144,6 @@ def test_openapi_stateless(case: schemathesis.Case): if case.operation.path in ( "/init_weight_transfer_engine", "/update_weights", - "/finalize_weight_update", ): return diff --git a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py index 582ce5bd87ff..9f2309c765b5 100644 --- a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py +++ b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py @@ -222,32 +222,9 @@ def check_update_called(self): assert shapes == test_shapes -@create_new_process_for_each_test() -def test_finalize_weight_update_runs(): - """Test that finalize_weight_update completes without error.""" - if torch.cuda.device_count() < 1: - pytest.skip("Need at least 1 GPU for this test") - - with patch( - "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine", - mock_create_engine, - ): - llm = LLM( - model=MODEL_NAME, - enforce_eager=True, - load_format="dummy", - tensor_parallel_size=1, - weight_transfer_config=WeightTransferConfig(backend="nccl"), - ) - - # finalize_weight_update should run without error - # (it calls process_weights_after_loading internally) - llm.finalize_weight_update() - - @create_new_process_for_each_test() def test_full_weight_transfer_flow(): - """Test the complete weight transfer flow: init -> update -> finalize.""" + """Test the complete weight transfer flow: init -> update.""" if torch.cuda.device_count() < 1: pytest.skip("Need at least 1 GPU for this test") @@ -282,9 +259,6 @@ def 
test_full_weight_transfer_flow(): ) ) - # Step 3: Finalize - llm.finalize_weight_update() - # Verify the full flow completed def check_flow(self): engine = self.weight_transfer_engine diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index acae613751a7..2bde7ba9c144 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -10,7 +10,6 @@ WeightTransferInitRequest, WeightTransferUpdateRequest, ) - from vllm.inputs.data import PromptType, StreamingInput from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, RequestOutput @@ -206,7 +205,3 @@ async def init_weight_transfer_engine( async def update_weights(self, request: WeightTransferUpdateRequest) -> None: """Batched weight update for RL training.""" raise NotImplementedError - - async def finalize_weight_update(self) -> None: - """Finalize the current weight update during RL training.""" - raise NotImplementedError diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a65eb4935c07..2f64e47b73b2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1754,13 +1754,6 @@ def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None: "update_weights", kwargs={"update_info": update_info_dict} ) - def finalize_weight_update(self) -> None: - """ - Finalize the weight update by processing weights for quantization/kernel format. - This should be called after all weight updates are complete. 
- """ - self.llm_engine.collective_rpc("finalize_weight_update") - def __repr__(self) -> str: """Return a transformers-style hierarchical view of the model.""" # Cache the result to avoid repeated collective_rpc calls diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py index 87b17fe8ff9c..38461b147781 100644 --- a/vllm/entrypoints/serve/rlhf/api_router.py +++ b/vllm/entrypoints/serve/rlhf/api_router.py @@ -139,12 +139,6 @@ async def update_weights(raw_request: Request): return JSONResponse(content={"message": "Weights updated"}) -@router.post("/finalize_weight_update") -async def finalize_weight_update(raw_request: Request): - await engine_client(raw_request).finalize_weight_update() - return JSONResponse(content={"message": "Weight update finalized"}) - - @router.get("/get_world_size") async def get_world_size( raw_request: Request, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ddb1965352b4..957780ab768a 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1054,9 +1054,3 @@ async def update_weights(self, request: WeightTransferUpdateRequest) -> None: await self.collective_rpc( "update_weights", kwargs={"update_info": update_info_dict} ) - - async def finalize_weight_update(self) -> None: - """ - Finalize the current weight update during RL training. - """ - await self.collective_rpc("finalize_weight_update") diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d98e52202060..2570398cf8cc 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1005,17 +1005,6 @@ def load_weights_direct( load_weights=load_weights_direct, ) - def finalize_weight_update(self) -> None: - """ - Finalize the weight update by processing weights for quantization/kernel format. - This should be called after all weight updates are complete. 
- """ - from vllm.model_executor.model_loader.utils import process_weights_after_loading - - process_weights_after_loading( - self.model_runner.model, self.model_config, self.device - ) - def shutdown(self) -> None: # has_kv_transfer_group can be None during interpreter shutdown. if ensure_kv_transfer_shutdown is not None: From f69383cb536cc272d4f175f916884c693b0dceb4 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Fri, 30 Jan 2026 13:46:10 -0800 Subject: [PATCH 34/36] fixes to online quant Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf.py | 29 +++++++++++++------ .../model_loader/reload/layerwise.py | 5 ++++ vllm/v1/worker/gpu_worker.py | 6 ++++ 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py index bfc6b2df952a..7ac13636b1a4 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -11,12 +11,13 @@ The example performs the following steps: * Load the training model on one gpu (scheduled via ray) -* Initialize the inference model with dummy weights across +* Initialize the inference model with real weights across two gpus using vLLM's tensor parallelism and Ray placement groups. -* Generate gibberish from a list of prompts using the randomly initialized +* Zero out all inference model weights to demonstrate weight syncing. +* Generate gibberish from a list of prompts using the zeroed inference engine. -* Update the weights of the training model and broadcast the updated weights - to the inference engine by using a Ray collective RPC group. +* Broadcast the real weights from the training model to the inference + engine using a Ray collective RPC group. * Generating from the list of prompts after weight sync should result in sensible outputs. 
@@ -40,8 +41,8 @@ ) from vllm.utils.network_utils import get_ip, get_open_port -# MODEL_NAME = "facebook/opt-125m" -MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128" +MODEL_NAME = "facebook/opt-125m" +# MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128" class MyLLM(LLM): @@ -51,6 +52,10 @@ def __init__(self, *args, **kwargs): os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1" super().__init__(*args, **kwargs) + def zero_weights(self): + """Zero out all model weights to demonstrate weight syncing works.""" + self.llm_engine.collective_rpc("zero_weights") + @ray.remote(num_gpus=1) class TrainModel: @@ -60,6 +65,7 @@ def __init__(self, model_name: str): self.model = AutoModelForCausalLM.from_pretrained( model_name, ).to("cuda:0") + self.port = get_open_port() self.master_address = get_ip() @@ -130,9 +136,14 @@ def broadcast_weights(self, packed: bool = True): data_parallel_size=1, distributed_executor_backend="ray", weight_transfer_config=WeightTransferConfig(backend="nccl"), - load_format="dummy", + quantization="fp8", ) +# Zero out all model weights to demonstrate that weight syncing works. +# After zeroing, the model will generate garbage. After syncing, it will +# generate sensible outputs. +ray.get(llm.zero_weights.remote()) + # Generate text from the prompts. prompts = [ "Hello, my name is", @@ -145,8 +156,8 @@ def broadcast_weights(self, packed: bool = True): outputs = ray.get(llm.generate.remote(prompts, sampling_params)) -# Generate text with the initial model. The output is expected to be nonsense -# because the weights are randomly initialized. +# Generate text with the zeroed model. The output is expected to be nonsense +# because all weights have been zeroed out. 
print("-" * 50) for output in outputs: prompt = output.prompt diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py index f7aaf8a677c4..21795e63995e 100644 --- a/vllm/model_executor/model_loader/reload/layerwise.py +++ b/vllm/model_executor/model_loader/reload/layerwise.py @@ -216,6 +216,11 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): # Materialize layer tensors onto device materialize_layer(layer) + # Reset FP8 online quantization flag so process_weights_after_loading + # will run again during reload + if hasattr(layer, "_already_called_process_weights_after_loading"): + delattr(layer, "_already_called_process_weights_after_loading") + # Unwrap layerwise loading wrappers for param in get_layer_tensors(layer).values(): param.weight_loader = _get_original_loader(param) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 2570398cf8cc..a90f11db3cb9 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1005,6 +1005,12 @@ def load_weights_direct( load_weights=load_weights_direct, ) + def zero_weights(self) -> None: + """Zero out all model weights. Useful for demonstrating weight syncing.""" + model = self.model_runner.model + for param in model.parameters(): + param.data.zero_() + def shutdown(self) -> None: # has_kv_transfer_group can be None during interpreter shutdown. 
if ensure_kv_transfer_shutdown is not None: From 56249b33f56f5b3fe1882e8a712bd39dbc675974 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Fri, 30 Jan 2026 15:31:12 -0800 Subject: [PATCH 35/36] fix examples Signed-off-by: ahao-anyscale --- .../new_weight_syncing/rlhf.py | 23 ++++++------------- .../model_executor/layers/quantization/fp8.py | 9 ++++++++ .../layers/quantization/ipex_quant.py | 6 +++++ vllm/v1/worker/gpu_worker.py | 6 ----- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf.py index 7ac13636b1a4..b3a3ca62f5a6 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf.py @@ -11,13 +11,12 @@ The example performs the following steps: * Load the training model on one gpu (scheduled via ray) -* Initialize the inference model with real weights across +* Initialize the inference model with dummy weights across two gpus using vLLM's tensor parallelism and Ray placement groups. -* Zero out all inference model weights to demonstrate weight syncing. -* Generate gibberish from a list of prompts using the zeroed +* Generate gibberish from a list of prompts using the randomly initialized inference engine. -* Broadcast the real weights from the training model to the inference - engine using a Ray collective RPC group. +* Update the weights of the training model and broadcast the updated weights + to the inference engine by using a Ray collective RPC group. * Generating from the list of prompts after weight sync should result in sensible outputs. 
@@ -52,10 +51,6 @@ def __init__(self, *args, **kwargs): os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1" super().__init__(*args, **kwargs) - def zero_weights(self): - """Zero out all model weights to demonstrate weight syncing works.""" - self.llm_engine.collective_rpc("zero_weights") - @ray.remote(num_gpus=1) class TrainModel: @@ -136,14 +131,10 @@ def broadcast_weights(self, packed: bool = True): data_parallel_size=1, distributed_executor_backend="ray", weight_transfer_config=WeightTransferConfig(backend="nccl"), + load_format="dummy", quantization="fp8", ) -# Zero out all model weights to demonstrate that weight syncing works. -# After zeroing, the model will generate garbage. After syncing, it will -# generate sensible outputs. -ray.get(llm.zero_weights.remote()) - # Generate text from the prompts. prompts = [ "Hello, my name is", @@ -156,8 +147,8 @@ def broadcast_weights(self, packed: bool = True): outputs = ray.get(llm.generate.remote(prompts, sampling_params)) -# Generate text with the zeroed model. The output is expected to be nonsense -# because all weights have been zeroed out. +# Generate text with the initial model. The output is expected to be nonsense +# because the weights are randomly initialized. print("-" * 50) for output in outputs: prompt = output.prompt diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 6436a9ae0abf..bbdedcd08ab7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -632,6 +632,9 @@ def process_weights_after_loading(self, layer: Module) -> None: ) # Activations not quantized for marlin. + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + class Fp8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. 
@@ -891,6 +894,9 @@ def process_weights_after_loading(self, layer: Module) -> None: layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, @@ -1164,6 +1170,9 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_input_scale, ) + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index f957b3991eed..3993c7d32293 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -315,6 +315,9 @@ def process_weights_after_loading(self, layer: Module) -> None: replace_parameter(layer, "weight_scale", weight_scale.data) layer.input_scale = None + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + def apply( self, layer: torch.nn.Module, @@ -376,6 +379,9 @@ def process_weights_after_loading(self, layer: Module) -> None: experts_start_id=ep_rank_start, ) + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index a90f11db3cb9..2570398cf8cc 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1005,12 +1005,6 @@ def load_weights_direct( load_weights=load_weights_direct, ) - def zero_weights(self) -> None: - """Zero out all model weights. 
Useful for demonstrating weight syncing.""" - model = self.model_runner.model - for param in model.parameters(): - param.data.zero_() - def shutdown(self) -> None: # has_kv_transfer_group can be None during interpreter shutdown. if ensure_kv_transfer_shutdown is not None: From 669b24ca1da3a556eede1194deb06cf73759db01 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Fri, 30 Jan 2026 15:46:10 -0800 Subject: [PATCH 36/36] x Signed-off-by: ahao-anyscale --- vllm/v1/worker/gpu_worker.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 2570398cf8cc..38c0f10f22ca 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -992,13 +992,10 @@ def update_weights(self, update_info: dict) -> None: # Weights are already in kernel format, copy directly def load_weights_direct( weights: list[tuple[str, torch.Tensor]], - ) -> set[str]: - loaded = set() + ) -> None: for name, weight in weights: param = model.get_parameter(name) param.copy_(weight) - loaded.add(name) - return loaded self.weight_transfer_engine.receive_weights( typed_update_info,