[Feat][RL][1/2] Native Weight Syncing API: NCCL #31943
Merged: robertgshaw2-redhat merged 56 commits into vllm-project:main from hao-aaron:weight_transfer on Feb 5, 2026.
Commits (56, all by hao-aaron):
- 21ef3d4 Weight transfer feature: incremental loading, async APIs, IPC/NCCL en…
- 037f968 updated async, added world size endpoints
- 676c7e3 bugfixes
- 0a935dd ipc fix
- 1f29fcd moved rlhf scripts
- 3471efb added tests
- 26249eb precommit fix
- 0d4e296 added packed tensors
- dae6946 added unit tests to CI
- 7849211 added env variables, fixes
- 4638bd4 Merge branch 'main' into weight_transfer
- 437b14f precommit fix
- 6a235f9 Merge branch 'main' into weight_transfer
- fc43e5b test fixes
- d6b4b88 Merge branch 'main' into weight_transfer
- e2dc668 fix examples
- c736f59 Merge branch 'main' into weight_transfer
- 50b6039 x
- e17f235 Merge branch 'weight_transfer' of github.com:ahao-anyscale/vllm into …
- fda8819 Merge branch 'main' into weight_transfer
- 9b341ff Merge branch 'main' into weight_transfer
- 7a25543 Merge branch 'main' into weight_transfer
- 7b30911 Merge branch 'main' into weight_transfer
- 8de0daa x
- ddb178d x
- abb69fb Merge branch 'main' into weight_transfer
- 4d69ed3 removed ipc
- 892b736 edit examples to start with random weights, then weight sync to train…
- 27a1441 added weight transfer factory
- bbc13e9 change config
- d44f46f x
- 844a84e x
- cb96ccd x
- 6fb9777 x
- baf5bcf x
- 5fa3896 Merge branch 'main' into weight_transfer
- 71364e1 x
- de6a6ca Merge branch 'weight_transfer' of https://github.com/ahao-anyscale/vl…
- ee9a5b5 x
- 19ed12a x
- 837d7ac Merge branch 'main' into weight_transfer
- a4b4239 x
- 09a66de Merge branch 'main' into weight_transfer
- 6810282 Merge branch 'main' into weight_transfer
- ab16e64 Merge branch 'main' into weight_transfer
- c7e89f7 x
- a003e63 x
- ac2b879 Merge branch 'main' into weight_transfer
- 7f8bb62 Merge branch 'main' into weight_transfer
- 1aced0b integrated layerwise reloading
- cc4c67e removed finalize weight update
- f69383c fixes to online quant
- 56249b3 fix examples
- 669b24c x
- a2b39df Merge upstream/main into weight_transfer
- c280dbc Merge upstream vllm up to 9f14c9224
New example file (208 lines added):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning with vLLM and Ray, using the native
weight syncing APIs exposed at the engine level.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformers model occupies one GPU for training, while a
2x tensor-parallel vLLM inference engine occupies two GPUs.

The example performs the following steps:
* Load the training model on one GPU (scheduled via Ray).
* Initialize the inference model with dummy weights across two GPUs using
  vLLM's tensor parallelism and Ray placement groups.
* Generate gibberish from a list of prompts using the randomly initialized
  inference engine.
* Update the weights of the training model and broadcast the updated
  weights to the inference engine over an NCCL weight-transfer group.
* Generate from the same prompts after the weight sync, which should now
  produce sensible outputs.

This example assumes a single-node cluster with three GPUs, but Ray also
supports multi-node clusters. vLLM expects the GPUs to be used only for
vLLM workloads; residual GPU activity interferes with vLLM's memory
profiling and causes unexpected behavior.
"""

import os

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port

MODEL_NAME = "facebook/opt-125m"
# MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128"


class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, *args, **kwargs):
        os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"
        super().__init__(*args, **kwargs)


@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU."""

    def __init__(self, model_name: str):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
        ).to("cuda:0")

        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes for weight transfer."""
        names = []
        dtype_names = []
        shapes = []
        for name, p in self.model.named_parameters():
            names.append(name)
            dtype_names.append(str(p.dtype).split(".")[-1])
            shapes.append(list(p.shape))
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer."""
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.master_address,
                master_port=self.port,
                world_size=world_size,
            ),
        )

    def broadcast_weights(self, packed: bool = True):
        """Broadcast weights to the inference engine."""
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            group=self.model_update_group,
            packed=packed,
        )


# Initialize Ray. The training model takes one GPU; the vLLM engine will be
# placed on the other two via a placement group.
ray.init()

# Launch the training model actor. Ray's resource scheduler allocates it
# 1 GPU (via num_gpus=1 in the decorator), so pg_inference gets different
# GPUs.
train_model = TrainModel.remote(MODEL_NAME)

# Create a placement group that reserves two GPUs for the vLLM inference
# engine. Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# Note: the weight transfer APIs (init_weight_transfer_engine,
# update_weights) are now native to vLLM workers.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model=MODEL_NAME,
    enforce_eager=True,
    tensor_parallel_size=2,
    data_parallel_size=1,
    distributed_executor_backend="ray",
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
    load_format="dummy",
    quantization="fp8",
)

# Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))

# Generate text with the initial model. The output is expected to be
# nonsense because the weights are randomly initialized.
print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(
    train_model.get_master_address_and_port.remote()
)

world_size = ray.get(llm.get_world_size.remote()) + 1  # +1 for the trainer
inference_handle = llm.init_weight_transfer_engine.remote(
    dict(
        init_info=dict(
            master_address=master_address,
            master_port=master_port,
            rank_offset=1,
            world_size=world_size,
        )
    )
)

# Initialize the weight transfer group on both the training actor and the
# inference engine.
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])

# Synchronize the updated weights to the inference engine using the batched
# API. First collect all weight metadata from the training actor.
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# Issue the update_weights call with NCCL-specific update info.
# packed=True enables efficient batched tensor broadcasting.
inference_handle = llm.update_weights.remote(
    dict(
        update_info=dict(
            names=names,
            dtype_names=dtype_names,
            shapes=shapes,
            packed=True,
        )
    )
)

# Broadcast all weights from the trainer using the weight transfer API.
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])

# Generate text with the updated model. The output is expected to be
# coherent because the weights have been synced from the trainer.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)
```
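The example puts the trainer at rank 0 of the NCCL group and the tensor-parallel workers after it (hence `world_size = engine world size + 1` and `rank_offset=1`), and `packed=True` broadcasts the weights as a flattened buffer. A minimal pure-Python sketch of that bookkeeping is below; the helper names are hypothetical and the byte layout (contiguous concatenation, no alignment padding) is an assumption for illustration only, since the real packing is internal to `NCCLWeightTransferEngine`.

```python
# Illustrative sketch only: the actual rank assignment and tensor packing
# live inside vllm.distributed.weight_transfer.nccl_engine.

DTYPE_SIZES = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


def group_ranks(engine_world_size: int) -> dict:
    """Trainer occupies rank 0; inference workers follow at rank_offset=1."""
    return {
        "world_size": engine_world_size + 1,  # +1 for the trainer
        "trainer_rank": 0,
        "worker_ranks": list(range(1, engine_world_size + 1)),
    }


def packed_offsets(names, dtype_names, shapes):
    """(byte offset, byte length) of each tensor in one flat buffer,
    assuming contiguous concatenation with no padding."""
    offsets, cursor = {}, 0
    for name, dtype, shape in zip(names, dtype_names, shapes):
        numel = 1
        for dim in shape:
            numel *= dim
        nbytes = numel * DTYPE_SIZES[dtype]
        offsets[name] = (cursor, nbytes)
        cursor += nbytes
    return offsets, cursor  # cursor == total buffer size in bytes


layout = group_ranks(engine_world_size=2)  # TP=2 engine, as in the example
offsets, total = packed_offsets(
    ["embed.weight", "lm_head.weight"],  # toy metadata, not real opt-125m
    ["float16", "float16"],
    [[4, 8], [8, 4]],
)
print(layout["world_size"], layout["worker_ranks"])  # 3 [1, 2]
print(offsets["lm_head.weight"], total)              # (64, 64) 128
```

With this layout, the trainer's single broadcast per packed buffer replaces one collective call per parameter, which is the point of the batched `packed=True` path.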