diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md index b2838215d4c..9d3f6beee11 100644 --- a/docs/source/vllm_integration.md +++ b/docs/source/vllm_integration.md @@ -3,7 +3,7 @@ This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥 > [!WARNING] -> TRL currently only supports vLLM versions `0.10.0`, `0.10.1`, and `0.10.2`. Please ensure you have one of these versions installed to avoid compatibility issues. +> TRL currently only supports vLLM version `0.10.2`. Please ensure you have this version installed to avoid compatibility issues. ## 🚀 How can I use vLLM with TRL to speed up training? diff --git a/setup.cfg b/setup.cfg index f84ba1f950e..cf82100e979 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,7 +62,7 @@ test = pytest-xdist pytest vllm = - vllm>=0.10.0,<=0.10.2 + vllm==0.10.2 fastapi pydantic requests diff --git a/trl/import_utils.py b/trl/import_utils.py index 0f15a17222c..10709dc549c 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -82,13 +82,10 @@ def is_uvicorn_available() -> bool: def is_vllm_available() -> bool: - if _vllm_available and ( - version.parse(_vllm_version) < version.parse("0.10.0") - or version.parse(_vllm_version) > version.parse("0.10.2") - ): + if _vllm_available and version.parse(_vllm_version) != version.parse("0.10.2"): warnings.warn( - "TRL currently only supports vLLM versions `0.10.0`, `0.10.1`, and `0.10.2`. You have version " - f"{_vllm_version} installed. We recommend to install one of these versions to avoid compatibility issues.", + f"TRL currently only supports vLLM version `0.10.2`. You have version {_vllm_version} installed. We " + "recommend to install this version to avoid compatibility issues.", UserWarning, ) return _vllm_available diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 22ab4df9275..4961adbe92d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -549,6 +549,8 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, + # Important so temperature scaling/logit tweaking affects the TIS log probs + logprobs_mode="processed_logprobs", ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1)