[docs] Add docs for new RL flows #36188
Merged
# Async Reinforcement Learning

## Overview

In a standard RL training loop, generation and training happen sequentially: the policy generates rollouts, then training runs on those rollouts, and the cycle repeats. During generation the training accelerators sit idle, and vice versa.

The **one-off pipelining** approach separates the generation and training phases into two parallel coroutines, allowing the model to generate new samples while simultaneously training on previously generated data. This can lead to better GPU utilization and greater training throughput.

However, this overlap introduces a complication: weights must be updated in the inference engine mid-flight, while requests may still be in progress.

## The Pause and Resume API

To safely update weights while the inference engine is running, vLLM provides `pause_generation` and `resume_generation` methods. These let the trainer coordinate a clean window for weight synchronization without losing in-flight work.

### pause_generation

```python
await engine.pause_generation(mode="keep", clear_cache=True)
```

The `mode` parameter controls how in-flight requests are handled:

| Mode | Behavior |
| ---- | -------- |
| `"abort"` | Abort all in-flight requests immediately and return partial results (default) |
| `"wait"` | Wait for all in-flight requests to finish before pausing |
| `"keep"` | Freeze requests in the queue; they resume when `resume_generation` is called |

The `clear_cache` parameter controls whether the KV cache and prefix cache are cleared after pausing.

### resume_generation

```python
await engine.resume_generation()
```

Resumes the scheduler after a pause. Any requests frozen with `mode="keep"` will continue generating.

### HTTP Endpoints

When using the vLLM HTTP server, the same functionality is available via:

- `POST /pause?mode=keep` - Pause generation
- `POST /resume` - Resume generation
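For example, driven with `curl` (the host and port are placeholders for wherever your server is listening):

```shell
# Pause generation, freezing in-flight requests (address is illustrative)
curl -X POST 'http://localhost:8000/pause?mode=keep'

# ... synchronize weights here ...

# Resume generation; frozen requests continue
curl -X POST 'http://localhost:8000/resume'
```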
!!! note "Data Parallelism"
    When using data parallelism with vLLM's **internal load balancer** (i.e. `data_parallel_backend="ray"`), pause and resume are handled automatically across all DP ranks -- a single call is sufficient. When using an **external load balancer** (i.e. multiple independent vLLM instances behind a proxy), you must send pause and resume requests to **every** engine instance individually before and after the weight update.

## Typical Async RL Flow

A typical async RL loop with weight syncing looks like this:

1. Start generating rollouts from the current policy
2. Once the trainer has new weights, pause generation with `mode="keep"`
3. Sync the updated weights from the trainer to the inference engine (see [Weight Transfer](weight_transfer/README.md))
4. Resume generation -- in-flight requests continue with the new weights
5. Repeat
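The steps above can be sketched with a stand-in engine object. `FakeEngine` and `sync_weights` below are illustrative placeholders, not vLLM APIs; only the shape of the `pause_generation`/`resume_generation` calls mirrors the real interface:

```python
import asyncio

class FakeEngine:
    """Illustrative stand-in for an inference engine exposing pause/resume."""
    def __init__(self):
        self.paused = False
        self.weights_version = 0

    async def pause_generation(self, mode="keep", clear_cache=True):
        # A real engine would freeze queued requests and optionally clear caches.
        self.paused = True

    async def resume_generation(self):
        self.paused = False

async def sync_weights(engine):
    # Placeholder for a real transfer (e.g. an NCCL broadcast from the trainer).
    engine.weights_version += 1

async def rl_loop(engine, steps=3):
    for _ in range(steps):
        # Steps 1-2: rollouts run until new weights are ready, then pause.
        await engine.pause_generation(mode="keep", clear_cache=True)
        # Step 3: push updated weights while the scheduler is quiesced.
        await sync_weights(engine)
        # Step 4: frozen requests continue with the new weights.
        await engine.resume_generation()

engine = FakeEngine()
asyncio.run(rl_loop(engine))
```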
The key insight is that requests paused with `mode="keep"` will produce tokens from the **old** weights before the pause and tokens from the **new** weights after resume. The `clear_cache` parameter controls whether the KV cache is invalidated during the pause. When `clear_cache=True`, previously cached key-value entries are discarded, so all tokens generated after resume are computed entirely with the new weights. When `clear_cache=False`, existing KV cache entries are retained, meaning some tokens in context may still reflect the old weights (stale KV cache).

## Example

The [async RLHF example](../examples/rl/rlhf_async_new_apis.md) demonstrates this pattern with `vllm.AsyncLLMEngine`, NCCL weight transfer, and mid-flight pause/resume with validation.
# Weight Transfer

vLLM provides a pluggable weight transfer system for synchronizing model weights from a training process to the inference engine during reinforcement learning (RL) workflows. This is essential for RLHF, GRPO, and other online RL methods where the policy model is iteratively updated during training and the updated weights must be reflected in the inference engine for rollout generation.

## Architecture

The weight transfer system follows a **two-phase protocol** with a pluggable backend design:

1. **Initialization** (`init_weight_transfer_engine`): Establishes the communication channel between the trainer and inference workers. Called once before the training loop begins.
2. **Weight Update** (`update_weights`): Transfers updated weights from the trainer to the inference engine. Called after each training step (or batch of steps).

## Available Backends

| Backend | Transport | Use Case |
| ------- | --------- | -------- |
| [NCCL](nccl.md) | NCCL broadcast | Separate GPUs for training and inference |
| [IPC](ipc.md) | CUDA IPC handles | Colocated training and inference on the same GPU |

## Configuration

Specify the weight transfer backend through `WeightTransferConfig`. The backend determines which engine handles the weight synchronization.

### Programmatic (Offline Inference)

```python
from vllm import LLM
from vllm.config import WeightTransferConfig

llm = LLM(
    model="my-model",
    weight_transfer_config=WeightTransferConfig(backend="nccl"),  # or "ipc"
)
```

### CLI (Online Serving)

```bash
vllm serve my-model \
    --weight-transfer-config '{"backend": "nccl"}'
```

The `backend` field accepts `"nccl"` (default) or `"ipc"`.

## API Endpoints

When running vLLM as an HTTP server, the following endpoints are available for weight transfer:

| Endpoint | Method | Description |
| -------- | ------ | ----------- |
| `/init_weight_transfer_engine` | POST | Initialize the weight transfer engine with backend-specific info |
| `/update_weights` | POST | Trigger a weight update with backend-specific metadata |
| `/pause` | POST | Pause generation before weight sync to handle in-flight requests |
| `/resume` | POST | Resume generation after weight sync |
| `/get_world_size` | GET | Get the number of inference workers (useful for NCCL world size calculation) |

!!! note
    The HTTP weight transfer endpoints require `VLLM_SERVER_DEV_MODE=1` to be set.
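A rough sketch of driving these endpoints from the trainer side. The host and port are placeholders, and the request payloads for initialization and updates are backend-specific (see the backend pages for the actual schemas):

```shell
# Address is illustrative; the server must be started with VLLM_SERVER_DEV_MODE=1
curl http://localhost:8000/get_world_size

# Initialization payload is backend-specific; '{...}' stands in for the real body
curl -X POST http://localhost:8000/init_weight_transfer_engine \
    -H 'Content-Type: application/json' -d '{...}'
```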
## Trainer-Side API

Both backends provide static methods that the trainer calls to send weights. The general pattern is:

```python
# 1. Initialize the transfer engine (backend-specific)
EngineClass.trainer_init(init_info)

# 2. Send weights to inference workers
EngineClass.trainer_send_weights(
    iterator=model.named_parameters(),
    trainer_args=backend_specific_args,
)
```

See the [NCCL](nccl.md) and [IPC](ipc.md) pages for backend-specific trainer APIs and full examples.

## Extending the System

The weight transfer system is designed to be extensible. You can implement custom backends by subclassing `WeightTransferEngine` and registering them with the factory. See the [Base Class](base.md) page for details.
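As a rough illustration of the shape such a backend might take: the method names below mirror the two-phase protocol described above, but the real base-class signatures live on the Base Class page and may differ. The `LoggingTransferEngine` is purely hypothetical and only records calls:

```python
from abc import ABC, abstractmethod

class WeightTransferEngine(ABC):
    """Simplified stand-in for the real base class (see base.md)."""

    @abstractmethod
    def init_weight_transfer_engine(self, init_info: dict) -> None:
        """Phase 1: establish the trainer <-> worker channel."""

    @abstractmethod
    def update_weights(self, metadata: dict) -> None:
        """Phase 2: receive one round of updated weights."""

class LoggingTransferEngine(WeightTransferEngine):
    """Hypothetical custom backend that records calls, for illustration only."""

    def __init__(self):
        self.events = []

    def init_weight_transfer_engine(self, init_info):
        self.events.append(("init", init_info))

    def update_weights(self, metadata):
        self.events.append(("update", metadata))

engine = LoggingTransferEngine()
engine.init_weight_transfer_engine({"backend": "logging"})
engine.update_weights({"step": 1})
```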