Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .buildkite/test-amd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,7 @@ steps:
- tests/compile/fullgraph/test_basic_correctness.py
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- examples/offline_inference/new_weight_syncing/
- examples/rl/
- tests/examples/offline_inference/data_parallel.py
- tests/v1/distributed
- tests/v1/engine/test_engine_core_client.py
Expand Down Expand Up @@ -1615,7 +1615,7 @@ steps:
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd
# NEW rlhf examples
- pushd ../examples/offline_inference/new_weight_syncing
- pushd ../examples/rl
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
Expand Down Expand Up @@ -2660,7 +2660,7 @@ steps:
- tests/v1/entrypoints/openai/test_multi_api_servers.py
- tests/v1/shutdown
- tests/v1/worker/test_worker_memory_snapshot.py
- examples/offline_inference/new_weight_syncing/
- examples/rl/
commands:
# Work around HIP bug tracked here: https://github.com/ROCm/hip/issues/3876
# TODO: Remove when the bug is fixed in a future ROCm release
Expand Down Expand Up @@ -3325,7 +3325,7 @@ steps:
- tests/compile/fullgraph/test_basic_correctness.py
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- examples/offline_inference/new_weight_syncing/
- examples/rl/
- tests/examples/offline_inference/data_parallel.py
- tests/v1/distributed
- tests/v1/engine/test_engine_core_client.py
Expand Down Expand Up @@ -3367,7 +3367,7 @@ steps:
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd
# NEW rlhf examples
- pushd ../examples/offline_inference/new_weight_syncing
- pushd ../examples/rl
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
Expand Down
29 changes: 12 additions & 17 deletions .buildkite/test_areas/distributed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,41 +82,36 @@ steps:

- label: Distributed Torchrun + Examples (4 GPUs)
timeout_in_minutes: 30
working_dir: "/vllm-workspace/tests"
working_dir: "/vllm-workspace"
num_devices: 4
source_file_dependencies:
- vllm/distributed/
- tests/distributed/test_torchrun_example.py
- tests/distributed/test_torchrun_example_moe.py
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- examples/offline_inference/new_weight_syncing/
- examples/rl/
- tests/examples/offline_inference/data_parallel.py
commands:
# https://github.com/NVIDIA/nccl/issues/1838
- export NCCL_CUMEM_HOST_ENABLE=0
# test with torchrun tp=2 and external_dp=2
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example.py
# test with torchrun tp=2 and pp=2
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- PP_SIZE=2 torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example.py
# test with torchrun tp=4 and dp=1
- TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
- TP_SIZE=4 torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example_moe.py
# test with torchrun tp=2, pp=2 and dp=1
- PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
- PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example_moe.py
# test with torchrun tp=1 and dp=4 with ep
- DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
- DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example_moe.py
# test with torchrun tp=2 and dp=2 with ep
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example_moe.py
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
# OLD rlhf examples
- cd ../examples/offline_inference
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
# NEW rlhf examples
- cd new_weight_syncing
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
- python3 examples/offline_inference/data_parallel.py --enforce-eager
# rlhf examples
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 examples/rl/rlhf_nccl.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 examples/rl/rlhf_ipc.py

- label: Distributed DP Tests (4 GPUs)
timeout_in_minutes: 30
Expand Down
10 changes: 9 additions & 1 deletion docs/mkdocs/hooks/generate_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@ def title(text: str) -> str:
# Custom substitutions
subs = {
"io": "IO",
"api": "API",
"rl": "RL",
"api(s?)": r"API\1",
"cli": "CLI",
"cpu": "CPU",
"ipc": "IPC",
"llm": "LLM",
"mae": "MAE",
"ner": "NER",
"tpu": "TPU",
"gguf": "GGUF",
"lora": "LoRA",
"nccl": "NCCL",
"rlhf": "RLHF",
"vllm": "vLLM",
"openai": "OpenAI",
Expand Down Expand Up @@ -196,6 +199,11 @@ def generate(self) -> str:


def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
# Monkey-patch dirname_to_title in awesome-nav so that sub-directory names are
# title-cased (e.g. "Offline Inference" instead of "Offline inference").
import mkdocs_awesome_nav.nav.directory as _nav_dir

_nav_dir.dirname_to_title = title
logger.info("Generating example documentation")
logger.debug("Root directory: %s", ROOT_DIR.resolve())
logger.debug("Example directory: %s", EXAMPLE_DIR.resolve())
Expand Down
63 changes: 63 additions & 0 deletions docs/training/async_rl.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Async Reinforcement Learning

## Overview

In a standard RL training loop, generation and training happen sequentially: the policy generates rollouts, then training runs on those rollouts, and the cycle repeats. During generation the training accelerators sit idle, and vice versa.

The **one-off pipelining** approach separates the generation and training phases into two parallel coroutines, allowing the model to generate new samples while simultaneously training on previously generated data. This can lead to better GPU utilization and greater training throughput.

However, this overlap introduces a complication: weights must be updated in the inference engine mid-flight, while requests may still be in progress.

## The Pause and Resume API

To safely update weights while the inference engine is running, vLLM provides `pause_generation` and `resume_generation` methods. These let the trainer coordinate a clean window for weight synchronization without losing in-flight work.

### pause_generation

```python
await engine.pause_generation(mode="keep", clear_cache=True)
```

The `mode` parameter controls how in-flight requests are handled:

| Mode | Behavior |
| ---- | -------- |
| `"abort"` | Abort all in-flight requests immediately and return partial results (default) |
| `"wait"` | Wait for all in-flight requests to finish before pausing |
| `"keep"` | Freeze requests in the queue; they resume when `resume_generation` is called |

The `clear_cache` parameter controls whether to clear the KV cache and prefix cache after pausing.

### resume_generation

```python
await engine.resume_generation()
```

Resumes the scheduler after a pause. Any requests frozen with `mode="keep"` will continue generating.

### HTTP Endpoints

When using the vLLM HTTP server, the same functionality is available via:

- `POST /pause?mode=keep` - Pause generation
- `POST /resume` - Resume generation

!!! note "Data Parallelism"
When using data parallelism with vLLM's **internal load balancer** (i.e. `data_parallel_backend="ray"`), pause and resume are handled automatically across all DP ranks -- a single call is sufficient. When using an **external load balancer** (i.e. multiple independent vLLM instances behind a proxy), you must send pause and resume requests to **every** engine instance individually before and after the weight update.

## Typical Async RL Flow

A typical async RL loop with weight syncing looks like this:

1. Start generating rollouts from the current policy
2. Once trainer has new weights to update to, pause generation with `mode="keep"`
3. Sync the updated weights from the trainer to the inference engine (see [Weight Transfer](weight_transfer/README.md))
4. Resume generation -- in-flight requests continue with the new weights
5. Repeat

The key insight is that requests paused with `mode="keep"` will produce tokens from the **old** weights before the pause and tokens from the **new** weights after resume. The `clear_cache` parameter controls whether the KV cache is invalidated during the pause. When `clear_cache=True`, previously cached key-value entries are discarded, so all tokens generated after resume will be computed entirely with the new weights. When `clear_cache=False`, existing KV cache entries are retained, meaning some tokens in context may still reflect the old weights (stale KV cache).

## Example

The [async RLHF example](../examples/rl/rlhf_async_new_apis.md) demonstrates this pattern with `vllm.AsyncLLMEngine`, NCCL weight transfer, and mid-flight pause/resume with validation.
6 changes: 2 additions & 4 deletions docs/training/rlhf.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ The following open-source RL libraries use vLLM for fast rollouts (sorted alphab
- [Unsloth](https://github.com/unslothai/unsloth)
- [verl](https://github.com/volcengine/verl)

See the following basic examples to get started if you don't want to use an existing library:
For weight synchronization between training and inference, see the [Weight Transfer](weight_transfer/README.md) documentation, which covers the pluggable backend system with [NCCL](weight_transfer/nccl.md) (multi-GPU) and [IPC](weight_transfer/ipc.md) (same-GPU) engines.

- [Training and inference processes are located on separate GPUs (inspired by OpenRLHF)](../examples/offline_inference/rlhf.md)
- [Training and inference processes are colocated on the same GPUs using Ray](../examples/offline_inference/rlhf_colocate.md)
- [Utilities for performing RLHF with vLLM](../examples/offline_inference/rlhf_utils.md)
For pipelining generation and training to improve GPU utilization and throughput, see the [Async Reinforcement Learning](async_rl.md) guide, which covers the pause/resume API for safely updating weights mid-flight.

See the following notebooks showing how to use vLLM for GRPO:

Expand Down
78 changes: 78 additions & 0 deletions docs/training/weight_transfer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Weight Transfer

vLLM provides a pluggable weight transfer system for synchronizing model weights from a training process to the inference engine during reinforcement learning (RL) workflows. This is essential for RLHF, GRPO, and other online RL methods where the policy model is iteratively updated during training and the updated weights must be reflected in the inference engine for rollout generation.

## Architecture

The weight transfer system follows a **two-phase protocol** with a pluggable backend design:

1. **Initialization** (`init_weight_transfer_engine`): Establishes the communication channel between the trainer and inference workers. Called once before the training loop begins.
2. **Weight Update** (`update_weights`): Transfers updated weights from the trainer to the inference engine. Called after each training step (or batch of steps).

## Available Backends

| Backend | Transport | Use Case |
| ------- | --------- | -------- |
| [NCCL](nccl.md) | NCCL broadcast | Separate GPUs for training and inference |
| [IPC](ipc.md) | CUDA IPC handles | Colocated training and inference on same GPU |

## Configuration

Specify the weight transfer backend through `WeightTransferConfig`. The backend determines which engine handles the weight synchronization.

### Programmatic (Offline Inference)

```python
from vllm import LLM
from vllm.config import WeightTransferConfig

llm = LLM(
model="my-model",
weight_transfer_config=WeightTransferConfig(backend="nccl"), # or "ipc"
)
```

### CLI (Online Serving)

```bash
vllm serve my-model \
--weight-transfer-config '{"backend": "nccl"}'
```

The `backend` field accepts `"nccl"` (default) or `"ipc"`.

## API Endpoints

When running vLLM as an HTTP server, the following endpoints are available for weight transfer:

| Endpoint | Method | Description |
| -------- | ------ | ----------- |
| `/init_weight_transfer_engine` | POST | Initialize the weight transfer engine with backend-specific info |
| `/update_weights` | POST | Trigger a weight update with backend-specific metadata |
| `/pause` | POST | Pause generation before weight sync to handle in-flight requests |
| `/resume` | POST | Resume generation after weight sync |
| `/get_world_size` | GET | Get the number of inference workers (useful for NCCL world size calculation) |

!!! note
The HTTP weight transfer endpoints require `VLLM_SERVER_DEV_MODE=1` to be set.

## Trainer-Side API

Both backends provide static methods that the trainer calls to send weights. The general pattern is:

```python
# 1. Initialize the transfer engine (backend-specific)
EngineClass.trainer_init(init_info)

# 2. Send weights to inference workers
EngineClass.trainer_send_weights(
iterator=model.named_parameters(),
trainer_args=backend_specific_args,
)
```

See the [NCCL](nccl.md) and [IPC](ipc.md) pages for backend-specific trainer APIs and full examples.

## Extending the System

The weight transfer system is designed to be extensible. You can implement custom backends by subclassing `WeightTransferEngine` and registering them with the factory. See the [Base Class](base.md) page for details.
Loading
Loading