Merged
40 changes: 13 additions & 27 deletions docs/source/vllm_integration.md
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood.

> [!WARNING]
> TRL currently only supports vLLM versions from `0.10.2` to `0.13.0`. Please ensure you have a version in this range installed to avoid compatibility issues.
> TRL currently only supports vLLM versions from `0.10.2` to `0.14.1`. Please ensure you have a version in this range installed to avoid compatibility issues.

> [!TIP]
> The following trainers currently support generation with vLLM:
@@ -31,12 +31,12 @@ pip install "trl[vllm]"
Then run the server on specific GPUs (e.g., GPUs 0-3):

```sh
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 4
```

Once the server is running, you can use it to generate completions for training. In the example below, the supported trainers use the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs.

In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script on different GPUs (e.g., GPUs 4-7) by passing `use_vllm=True` in the training arguments as follows:
In this example, we shard one model across 4 GPUs with tensor parallelism. Then, run the training script on different GPUs (e.g., GPUs 4-7) by passing `use_vllm=True` in the training arguments as follows:

Sample of a simple `train.py` script:

@@ -166,19 +166,15 @@ If you've ever done autoregressive decoder training, you know all the input toke
When you run for example

```sh
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 4
```

the following happens:
1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (4 × 1).
Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation.

![vllm](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/vllm-doc.png)
2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled.

1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (4 × 1).
Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load.

2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas.

3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts).
3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Each GPU handles only part of the work: with data parallelism it serves a slice of the prompts, while with tensor parallelism (for example, `--tensor-parallel-size=4`) it holds a shard of the model weights and cooperates with the other GPUs to serve every prompt.
This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself.
Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings.
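The bookkeeping in steps 1-3 can be sketched with a toy script. This is plain Python, not vLLM's actual scheduler; `shard_prompts` is an illustrative helper, shown for the DP layout (`tp=1, dp=4`) used in the older example above:

```python
from math import ceil

def shard_prompts(prompts, data_parallel_size):
    # Illustrative only: split the incoming prompts into one chunk per DP worker.
    chunk = ceil(len(prompts) / data_parallel_size)
    return [prompts[i * chunk : (i + 1) * chunk] for i in range(data_parallel_size)]

tensor_parallel_size, data_parallel_size = 1, 4
num_workers = tensor_parallel_size * data_parallel_size  # step 1: tp x dp = 4 workers

prompts = [f"prompt-{i}" for i in range(8)]
shards = shard_prompts(prompts, data_parallel_size)      # step 3: 2 prompts per DP worker

num_generations = 2  # completions per prompt, from the GRPO config
total_completions = len(prompts) * num_generations       # 16, regardless of parallelism
```

With 8 prompts, each of the 4 DP workers receives 2 prompts, and the completion count depends only on the number of prompts and `num_generations`, never on the parallelism settings.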

@@ -224,7 +220,9 @@ options:
--tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
Number of tensor parallel workers to use. (default: 1)
--data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE
Number of data parallel workers to use. (default: 1)
Number of data parallel workers to use. For dense models, keep this at 1. Starting from vLLM `0.14.0`,
setting this above `1` for dense models is no longer supported and will raise an error (see vLLM PR #30739).
(default: 1)
--host HOST Host address to run the server on. (default: 0.0.0.0)
--port PORT Port to run the server on. (default: 8000)
--gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
@@ -259,20 +257,8 @@ options:
![tp dp throughput 8 gpus](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_8_gpus.png)
![tp dp throughput 4 gpus](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_4_gpus.png)

First and foremost, always remember that the optimal setup depends on:

- The model size
- The number of GPUs you have
- The GPU memory size
- The batch size you are using
- The number of requests you are sending to the server (prompts)
- The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
- The number of completions you are generating for each request (`num_generations`)

Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:

- For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results.
- For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window.
> [!WARNING]
> The benchmark plots above were collected with older vLLM versions. Starting with [vLLM PR #30739](https://github.com/vllm-project/vllm/pull/30739) (released in `0.14.0`), offline data parallel scaling for non-MoE (dense) models is no longer supported. To follow the latest recommendations, do not scale DP for non-MoE models.
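
To make the rule concrete, a minimal pre-flight check could look like the following. `dp_allowed` is a hypothetical helper (not part of TRL or vLLM), assuming a plain `major.minor.patch` version string:

```python
def dp_allowed(vllm_version: str, data_parallel_size: int, is_moe: bool) -> bool:
    """Hypothetical guard: DP > 1 for dense (non-MoE) models is rejected on vLLM >= 0.14.0."""
    major, minor = (int(part) for part in vllm_version.split(".")[:2])
    if (major, minor) >= (0, 14) and not is_moe and data_parallel_size > 1:
        return False
    return True

# Dense model: DP scaling only worked before 0.14.0.
assert dp_allowed("0.13.0", 2, is_moe=False)
assert not dp_allowed("0.14.1", 2, is_moe=False)
# MoE models may still use DP.
assert dp_allowed("0.14.1", 4, is_moe=True)
```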

### vLLM with Transformers Backend

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -83,7 +83,7 @@ test = [
"pytest"
]
vllm = [
"vllm>=0.10.2,<0.14.0",
"vllm>=0.10.2,<=0.14.1",
"fastapi",
"pydantic",
"requests",
50 changes: 50 additions & 0 deletions tests/test_vllm_client_server.py
@@ -17,6 +17,7 @@
from types import SimpleNamespace

import pytest
from packaging.version import Version
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import torch_device

@@ -35,8 +36,13 @@


if is_vllm_available():
import vllm
from vllm import LLM, SamplingParams

_is_vllm_ge_014 = Version(vllm.__version__) >= Version("0.14.0")
else:
_is_vllm_ge_014 = False


class TestChunkList(TrlTestCase):
def test_even_split(self):
@@ -530,6 +536,26 @@ def multiply(a: int, b: int) -> int:
decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
assert "Multiplies two integers." in decoded_prompt

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
"completion_ids"
]

# Check that the output is a list
assert isinstance(completion_ids, list)

# Check that the number of generated sequences is 2 times the number of prompts
assert len(completion_ids) == 2 * len(prompts)

# Check that the generated sequences are lists of integers
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)

# Check that the length of the generated sequences is less than or equal to 32
for seq in completion_ids:
assert len(seq) <= 32

Comment on lines +539 to +558 (Member Author):

> not specific to vllm 0.14, but I realized that this test case was missing

def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)
@@ -549,6 +575,10 @@ def teardown_class(cls):


@pytest.mark.slow
@pytest.mark.skipif(
_is_vllm_ge_014,
reason="Skipping DP server test for vLLM>=0.14.0 (PR vllm#30739: DP for non-MoE/dense models no longer supported).",
)
@require_3_accelerators
@require_vllm
class TestVLLMClientServerDP(TrlTestCase):
@@ -635,6 +665,26 @@ def multiply(a: int, b: int) -> int:
decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
assert "Multiplies two integers." in decoded_prompt

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
"completion_ids"
]

# Check that the output is a list
assert isinstance(completion_ids, list)

# Check that the number of generated sequences is 2 times the number of prompts
assert len(completion_ids) == 2 * len(prompts)

# Check that the generated sequences are lists of integers
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)

# Check that the length of the generated sequences is less than or equal to 32
for seq in completion_ids:
assert len(seq) <= 32

Comment on lines +668 to +687
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)
2 changes: 1 addition & 1 deletion trl/_compat.py
@@ -89,7 +89,7 @@ def _patch_vllm_disabled_tqdm() -> None:

- Bug introduced in https://github.com/vllm-project/vllm/pull/52
- Fixed in https://github.com/vllm-project/vllm/pull/28471 (released in v0.11.1)
- Since TRL currently supports vLLM v0.10.2-0.13.0, we patch it here
- Since TRL currently supports vLLM v0.10.2-0.14.1, we patch it here
- This can be removed when TRL requires vLLM>=0.11.1
"""
if _is_package_version_below("vllm", "0.11.1"):
4 changes: 2 additions & 2 deletions trl/import_utils.py
@@ -109,9 +109,9 @@ def is_uvicorn_available() -> bool:
def is_vllm_available() -> bool:
_vllm_available, _vllm_version = _is_package_available("vllm", return_version=True)
if _vllm_available:
if not (Version("0.10.2") <= Version(_vllm_version) <= Version("0.13.0")):
if not (Version("0.10.2") <= Version(_vllm_version) <= Version("0.14.1")):
warnings.warn(
f"TRL currently supports vLLM versions from 0.10.2 to 0.13.0. You have version {_vllm_version} "
f"TRL currently supports vLLM versions from 0.10.2 to 0.14.1. You have version {_vllm_version} "
"installed. We recommend installing a supported version to avoid compatibility issues.",
stacklevel=2,
)
10 changes: 8 additions & 2 deletions trl/scripts/vllm_serve.py
@@ -211,7 +211,9 @@ class ScriptArguments:
tensor_parallel_size (`int`, *optional*, defaults to `1`):
Number of tensor parallel workers to use.
data_parallel_size (`int`, *optional*, defaults to `1`):
Number of data parallel workers to use.
Number of data parallel workers to use. For dense models, keep this at 1. Starting from vLLM `0.14.0`,
setting this above `1` for dense models is no longer supported and will raise an error (see vLLM PR
#30739).
host (`str`, *optional*, defaults to `"0.0.0.0"`):
Host address to run the server on.
port (`int`, *optional*, defaults to `8000`):
Expand Down Expand Up @@ -261,7 +263,11 @@ class ScriptArguments:
)
data_parallel_size: int = field(
default=1,
metadata={"help": "Number of data parallel workers to use."},
metadata={
"help": "Number of data parallel workers to use. For dense models, keep this at 1. Starting from vLLM "
"`0.14.0`, setting this above `1` for dense models is no longer supported and will raise an error "
"(see vLLM PR #30739)."
},
)
host: str = field(
default="0.0.0.0",