2 changes: 1 addition & 1 deletion trl/trainer/grpo_config.py
@@ -147,7 +147,7 @@ class GRPOConfig(TrainingArguments):
Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to
`"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.
- vllm_max_model_length (`int`, *optional*, defaults to `None`):
+ vllm_max_model_length (`int`, *optional*):
Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus
`max_completion_length`; if omitted, it is inferred from the model config.
vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
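For context, a minimal sketch of how a user might set this field on `GRPOConfig`, assuming the longest prompt length in the dataset is known (512 tokens is a placeholder); `use_vllm` and `max_completion_length` are existing config fields, and all values here are illustrative only.

```python
# Hedged sketch (not part of this PR): choose vllm_max_model_length so it
# covers prompt + completion, per the docstring above.
from trl import GRPOConfig

longest_prompt_tokens = 512   # assumed, measured from the dataset
max_completion_length = 256

config = GRPOConfig(
    output_dir="grpo-output",
    use_vllm=True,
    max_completion_length=max_completion_length,
    # Leaving this as None lets vLLM infer the context window from the model config.
    vllm_max_model_length=longest_prompt_tokens + max_completion_length,
)
```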
10 changes: 10 additions & 0 deletions trl/trainer/rloo_config.py
@@ -144,6 +144,9 @@ class RLOOConfig(TrainingArguments):
Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to
`"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.
+ vllm_max_model_length (`int`, *optional*):
+ Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus
+ `max_completion_length`; if omitted, it is inferred from the model config.
vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to
`"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
@@ -458,6 +461,13 @@ class RLOOConfig(TrainingArguments):
"launching the vLLM server via the `--vllm_gpu_memory_utilization` flag."
},
)
+ vllm_max_model_length: int | None = field(
+ default=None,
+ metadata={
+ "help": "Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus "
+ "`max_completion_length`; if omitted, it is inferred from the model config."
+ },
+ )
vllm_tensor_parallel_size: int = field(
default=1,
metadata={
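The new field mirrors its GRPO counterpart. A hedged usage sketch, assuming `use_vllm`, `vllm_mode`, and `max_completion_length` exist on `RLOOConfig` as the surrounding docstrings suggest; values are placeholders.

```python
# Hedged sketch (not part of this PR): the new RLOOConfig field in use.
from trl import RLOOConfig

config = RLOOConfig(
    output_dir="rloo-output",
    use_vllm=True,
    vllm_mode="colocate",
    max_completion_length=256,
    # None (the default) lets vLLM infer the context window from the
    # model config, matching the help string added above.
    vllm_max_model_length=None,
)
```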
6 changes: 1 addition & 5 deletions trl/trainer/rloo_trainer.py
@@ -508,10 +508,6 @@ def __init__(
# Ensure distributed rendezvous variables are set without colliding across concurrent runs
ensure_master_addr_port()

- if self.max_prompt_length is not None and self.max_completion_length is not None:
- max_model_len = self.max_prompt_length + self.max_completion_length
- else:
- max_model_len = None
vllm_quantization = None
if is_bitsandbytes_available():
for _, module in model.named_modules():
@@ -527,7 +523,7 @@ def __init__(
max_num_seqs=self.args.per_device_train_batch_size
* self.vllm_tensor_parallel_size
* self.args.steps_per_generation,
- max_model_len=max_model_len,
+ max_model_len=self.args.vllm_max_model_length,
distributed_executor_backend="external_launcher",
# Feed identical seed for tp groups to ensure sampling results are the same across workers
seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
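To summarize the behavioral change in `rloo_trainer.py`: the trainer no longer derives vLLM's `max_model_len` from `max_prompt_length + max_completion_length`, but passes `args.vllm_max_model_length` straight through, with `None` meaning vLLM infers the context window from the model config. A standalone sketch of the two resolution rules; the helper names are illustrative, not the trainer's actual code.

```python
# Hedged sketch of the behavior change, not the trainer's implementation.

def old_max_model_len(max_prompt_length, max_completion_length):
    # Pre-diff: derived from the two length limits when both are set.
    if max_prompt_length is not None and max_completion_length is not None:
        return max_prompt_length + max_completion_length
    return None

def new_max_model_len(vllm_max_model_length):
    # Post-diff: taken verbatim from RLOOConfig; None means vLLM infers
    # the context window from the model config.
    return vllm_max_model_length

print(old_max_model_len(512, 256))   # 768
print(new_max_model_len(None))       # None -> inferred by vLLM
```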