diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md index 82261e6dd8..0543bd4c13 100644 --- a/docs/source/vllm_integration.md +++ b/docs/source/vllm_integration.md @@ -3,7 +3,7 @@ This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. > [!WARNING] -> TRL currently only supports vLLM versions from `0.10.2` to `0.13.0`. Please ensure you have a version in this range installed to avoid compatibility issues. +> TRL currently only supports vLLM versions from `0.10.2` to `0.14.1`. Please ensure you have a version in this range installed to avoid compatibility issues. > [!TIP] > The following trainers currently support generation with vLLM: @@ -31,12 +31,12 @@ pip install "trl[vllm]" Then run the server on specific GPUs (e.g., GPUs 0-3): ```sh -CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2 +CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 4 ``` Once the server is running, you can use it to generate completions for training. In the example below, we are using the different supported trainers using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs. -In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script on different GPUs (e.g., GPUs 4-7) by passing `use_vllm=True` in the training arguments as follows: +In this example, we shard one model across 4 GPUs with tensor parallelism. Then, run the training script on different GPUs (e.g., GPUs 4-7) by passing `use_vllm=True` in the training arguments as follows: Sample of a simple `train.py` script: @@ -166,19 +166,15 @@ If you've ever done autoregressive decoder training, you know all the input toke When you run for example ```sh -CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4 +CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 4 ``` -the following happens: +1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (4 × 1). +Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. -![vllm](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/vllm-doc.png) +2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. -1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (1 × 4). -Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load. - -2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas. - -3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts). +3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--tensor-parallel-size=4`, each GPU participates in serving the full model). This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself. Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings. @@ -224,7 +220,9 @@ options: --tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE Number of tensor parallel workers to use. (default: 1) --data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE - Number of data parallel workers to use. (default: 1) + Number of data parallel workers to use. For dense models, keep this at 1. Starting from vLLM `0.14.0`, setting + this above `1` for dense models is no longer supported/useful and will error out (see vLLM PR #30739). + (default: 1) --host HOST Host address to run the server on. (default: 0.0.0.0) --port PORT Port to run the server on. (default: 8000) --gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION @@ -259,20 +257,8 @@ options: ![tp dp throughput 8 gpus](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_8_gpus.png) ![tp dp throughput 4 gpus](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_4_gpus.png) -First and foremost, always remember that the optimal setup depends on: - -- The model size -- The number of GPUs you have -- The GPU memory size -- The batch size you are using -- The number of requests you are sending to the server (prompts) -- The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size) -- The number of completions you are generating for each request (`num_generations`) - -Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that: - -- For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results. -- For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window. +> [!WARNING] +> The benchmark plots above were collected with older vLLM versions. Starting with [vLLM PR #30739](https://github.com/vllm-project/vllm/pull/30739) (released in `0.14.0`), offline data parallel scaling for non-MoE (dense) models is no longer supported. To follow the latest recommendations, do not scale DP for non-MoE models. ### vLLM with Transformers Backend diff --git a/pyproject.toml b/pyproject.toml index df64506c0b..c8f0c86195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ test = [ "pytest" ] vllm = [ - "vllm>=0.10.2,<0.14.0", + "vllm>=0.10.2,<=0.14.1", "fastapi", "pydantic", "requests", diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 91ba58d208..7c14af14e9 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -17,6 +17,7 @@ from types import SimpleNamespace import pytest +from packaging.version import Version from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.testing_utils import torch_device @@ -35,8 +36,13 @@ if is_vllm_available(): + import vllm from vllm import LLM, SamplingParams + _is_vllm_ge_014 = Version(vllm.__version__) >= Version("0.14.0") +else: + _is_vllm_ge_014 = False + class TestChunkList(TrlTestCase): def test_even_split(self): @@ -530,6 +536,26 @@ def multiply(a: int, b: int) -> int: decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0]) assert "Multiplies two integers." in decoded_prompt + def test_generate_with_params(self): + prompts = ["Hello, AI!", "Tell me a joke"] + completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ + "completion_ids" + ] + + # Check that the output is a list + assert isinstance(completion_ids, list) + + # Check that the number of generated sequences is 2 times the number of prompts + assert len(completion_ids) == 2 * len(prompts) + + # Check that the generated sequences are lists of integers + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + # Check that the length of the generated sequences is less than or equal to 32 + for seq in completion_ids: + assert len(seq) <= 32 + def test_update_model_params(self): model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) self.client.update_model_params(model) @@ -549,6 +575,10 @@ def teardown_class(cls): @pytest.mark.slow +@pytest.mark.skipif( + _is_vllm_ge_014, + reason="Skipping DP server test for vLLM>=0.14.0 (PR vllm#30739: DP for non-MoE/dense models no longer supported).", +) @require_3_accelerators @require_vllm class TestVLLMClientServerDP(TrlTestCase): @@ -635,6 +665,26 @@ def multiply(a: int, b: int) -> int: decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0]) assert "Multiplies two integers." in decoded_prompt + def test_generate_with_params(self): + prompts = ["Hello, AI!", "Tell me a joke"] + completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ + "completion_ids" + ] + + # Check that the output is a list + assert isinstance(completion_ids, list) + + # Check that the number of generated sequences is 2 times the number of prompts + assert len(completion_ids) == 2 * len(prompts) + + # Check that the generated sequences are lists of integers + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + # Check that the length of the generated sequences is less than or equal to 32 + for seq in completion_ids: + assert len(seq) <= 32 + def test_update_model_params(self): model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) self.client.update_model_params(model) diff --git a/trl/_compat.py b/trl/_compat.py index 779a6e40f8..f42f9a081e 100644 --- a/trl/_compat.py +++ b/trl/_compat.py @@ -89,7 +89,7 @@ def _patch_vllm_disabled_tqdm() -> None: - Bug introduced in https://github.com/vllm-project/vllm/pull/52 - Fixed in https://github.com/vllm-project/vllm/pull/28471 (released in v0.11.1) - - Since TRL currently supports vLLM v0.10.2-0.13.0, we patch it here + - Since TRL currently supports vLLM v0.10.2-0.14.1, we patch it here - This can be removed when TRL requires vLLM>=0.11.1 """ if _is_package_version_below("vllm", "0.11.1"): diff --git a/trl/import_utils.py b/trl/import_utils.py index 23cbeb93c3..66af33f168 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -109,9 +109,9 @@ def is_uvicorn_available() -> bool: def is_vllm_available() -> bool: _vllm_available, _vllm_version = _is_package_available("vllm", return_version=True) if _vllm_available: - if not (Version("0.10.2") <= Version(_vllm_version) <= Version("0.13.0")): + if not (Version("0.10.2") <= Version(_vllm_version) <= Version("0.14.1")): warnings.warn( - f"TRL currently supports vLLM versions from 0.10.2 to 0.13.0. You have version {_vllm_version} " + f"TRL currently supports vLLM versions from 0.10.2 to 0.14.1. You have version {_vllm_version} " "installed. We recommend installing a supported version to avoid compatibility issues.", stacklevel=2, ) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 61f2a99b1e..fed2af86bf 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -211,7 +211,9 @@ class ScriptArguments: tensor_parallel_size (`int`, *optional*, defaults to `1`): Number of tensor parallel workers to use. data_parallel_size (`int`, *optional*, defaults to `1`): - Number of data parallel workers to use. + Number of data parallel workers to use. For dense models, keep this at 1. Starting from vLLM `0.14.0`, + setting this above `1` for dense models is no longer supported/useful and will error out (see vLLM PR + #30739). host (`str`, *optional*, defaults to `"0.0.0.0"`): Host address to run the server on. port (`int`, *optional*, defaults to `8000`): @@ -261,7 +263,11 @@ class ScriptArguments: ) data_parallel_size: int = field( default=1, - metadata={"help": "Number of data parallel workers to use."}, + metadata={ + "help": "Number of data parallel workers to use. For dense models, keep this at 1. Starting from vLLM " + "`0.14.0`, setting this above `1` for dense models is no longer supported/useful and will error out (see " + "vLLM PR #30739)." + }, ) host: str = field( default="0.0.0.0",