diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 40af62a57f6..115937af2d2 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -147,7 +147,7 @@ class GRPOConfig(TrainingArguments):
             Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to
             `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
             launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.
-        vllm_max_model_length (`int`, *optional*, defaults to `None`):
+        vllm_max_model_length (`int`, *optional*):
             Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus
             `max_completion_length`; if omitted, it is inferred from the model config.
         vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py
index 921afa697a1..fd81338ce2a 100644
--- a/trl/trainer/rloo_config.py
+++ b/trl/trainer/rloo_config.py
@@ -144,6 +144,9 @@ class RLOOConfig(TrainingArguments):
             Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to
             `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
             launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.
+        vllm_max_model_length (`int`, *optional*):
+            Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus
+            `max_completion_length`; if omitted, it is inferred from the model config.
         vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
             Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to
             `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
@@ -458,6 +461,13 @@ class RLOOConfig(TrainingArguments):
             "launching the vLLM server via the `--vllm_gpu_memory_utilization` flag."
         },
     )
+    vllm_max_model_length: int | None = field(
+        default=None,
+        metadata={
+            "help": "Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus "
+            "`max_completion_length`; if omitted, it is inferred from the model config."
+        },
+    )
     vllm_tensor_parallel_size: int = field(
         default=1,
         metadata={
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index bf4e2a97648..fedb8d6081a 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -508,10 +508,6 @@ def __init__(
         # Ensure distributed rendezvous variables are set without colliding across concurrent runs
         ensure_master_addr_port()
 
-        if self.max_prompt_length is not None and self.max_completion_length is not None:
-            max_model_len = self.max_prompt_length + self.max_completion_length
-        else:
-            max_model_len = None
         vllm_quantization = None
         if is_bitsandbytes_available():
             for _, module in model.named_modules():
@@ -527,7 +523,7 @@ def __init__(
                     max_num_seqs=self.args.per_device_train_batch_size
                     * self.vllm_tensor_parallel_size
                     * self.args.steps_per_generation,
-                    max_model_len=max_model_len,
+                    max_model_len=self.args.vllm_max_model_length,
                     distributed_executor_backend="external_launcher",
                     # Feed identical seed for tp groups to ensure sampling results are the same across workers
                     seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
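
For reference, a minimal usage sketch of the new `vllm_max_model_length` field added to `RLOOConfig` by this patch. The field name and its meaning come from the diff above; the other arguments (`use_vllm`, `vllm_mode`, `max_prompt_length`, `max_completion_length`, `output_dir`) are assumed to exist on `RLOOConfig` as in the current TRL API and are not part of this change.

```python
from trl import RLOOConfig

config = RLOOConfig(
    output_dir="rloo-run",
    use_vllm=True,                 # assumed existing flag; vLLM must be enabled for this to matter
    vllm_mode="colocate",
    max_prompt_length=512,
    max_completion_length=256,
    # New in this patch: set vLLM's context window explicitly instead of
    # relying on max_prompt_length + max_completion_length being derived in the trainer.
    vllm_max_model_length=1024,    # >= longest prompt in the dataset + max_completion_length
)
```

If `vllm_max_model_length` is left unset, the trainer now passes `None` to vLLM's `max_model_len`, so the context window falls back to the value inferred from the model config.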