2 changes: 1 addition & 1 deletion trl/trainer/grpo_config.py
@@ -294,7 +294,7 @@ class GRPOConfig(TrainingArguments):

> Deprecated arguments

max_prompt_length (`bool`, *optional*):
max_prompt_length:

<Deprecated version="0.26.0">

37 changes: 29 additions & 8 deletions trl/trainer/rloo_config.py
@@ -46,8 +46,6 @@ class RLOOConfig(TrainingArguments):
remove_unused_columns (`bool`, *optional*, defaults to `False`):
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
num_generations (`int`, *optional*, defaults to `2`):
Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size
* gradient_accumulation_steps) must be evenly divisible by this value.
@@ -204,6 +202,17 @@ class RLOOConfig(TrainingArguments):
log_unique_prompts (`bool`, *optional*, defaults to `False`):
Whether to log unique prompts. If `True`, only unique prompts are logged. If `False`, all prompts are
logged.

> Deprecated arguments

max_prompt_length:

<Deprecated version="0.26.0">

Parameter `max_prompt_length` is deprecated and will be removed in version 0.28.0. You should instead
filter your dataset before training to ensure that prompts do not exceed your desired length.

</Deprecated>
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
@@ -271,12 +280,6 @@ class RLOOConfig(TrainingArguments):
"that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
},
)
max_prompt_length: int | None = field(
default=512,
metadata={
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
},
)
num_generations: int | None = field(
default=2,
metadata={
@@ -569,6 +572,15 @@ class RLOOConfig(TrainingArguments):
},
)

# Deprecated arguments
max_prompt_length: int | None = field(
default=None,
metadata={
"help": "Deprecated, filter your dataset before training to ensure that prompts do not exceed your "
"desired length."
},
)

def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
if self.top_k is None:
@@ -628,3 +640,12 @@ def __post_init__(self):
"RLOO requires at least 2 generations per prompt to calculate the advantages. You provided "
f"{self.num_generations}, which is less than the minimum required."
)

if self.max_prompt_length is not None:
warnings.warn(
"The `max_prompt_length` argument is deprecated and will be removed in version 0.28.0. You should "
"instead filter your dataset before training to ensure that prompts do not exceed your desired "
"length.",
FutureWarning,
stacklevel=2,
)
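
Because the deprecated `max_prompt_length` no longer truncates prompts, the config now points users to filtering the dataset up front. A minimal sketch of what that might look like, assuming a `datasets.Dataset` with a plain-text `"prompt"` column; the model name, dataset name, and the 512-token threshold are placeholders, not part of this PR:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Placeholder model and dataset names; substitute your own.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/tldr", split="train")

MAX_PROMPT_TOKENS = 512  # the old `max_prompt_length` default


def prompt_is_short_enough(example):
    # Count prompt tokens without special tokens, then keep only prompts
    # that fit within the desired budget instead of truncating them.
    n_tokens = len(tokenizer(example["prompt"], add_special_tokens=False)["input_ids"])
    return n_tokens <= MAX_PROMPT_TOKENS


dataset = dataset.filter(prompt_is_short_enough)
```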
27 changes: 7 additions & 20 deletions trl/trainer/rloo_trainer.py
@@ -365,7 +365,6 @@ def __init__(
self.reward_processing_classes = reward_processing_classes

# Training arguments
self.max_prompt_length = args.max_prompt_length
self.max_completion_length = args.max_completion_length
self.num_generations = args.num_generations
self.num_generations_eval = args.num_generations_eval or self.num_generations
@@ -1044,7 +1043,6 @@ def _generate_single_turn(self, prompts: list):
"top_k": self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"truncate_prompt_tokens": self.max_prompt_length,
"guided_decoding_regex": self.guided_decoding_regex,
"generation_kwargs": self.args.generation_kwargs,
}
@@ -1091,7 +1089,6 @@ def _generate_single_turn(self, prompts: list):
"top_k": self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"truncate_prompt_tokens": self.max_prompt_length,
"guided_decoding": guided_decoding,
}
if self.args.generation_kwargs is not None:
@@ -1135,22 +1132,16 @@ def _generate_single_turn(self, prompts: list):
self.llm.sleep(level=2)

elif self.use_transformers_paged:
processor_kwargs = {
"max_length": self.max_prompt_length,
"truncation": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
processor_outputs = self.processing_class.apply_chat_template(
conversation=prompts,
**processor_kwargs,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
**self.chat_template_kwargs,
)
else:
processor_outputs = self.processing_class(text=prompts, **processor_kwargs)
processor_outputs = self.processing_class(text=prompts)

with (
profiling_context(self, "transformers.generate_batch"),
@@ -1176,25 +1167,21 @@

else:
# Regular generation path
processor_kwargs = {
"return_tensors": "pt",
"padding": True,
"padding_side": "left",
"max_length": self.max_prompt_length,
"truncation": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts,
**processor_kwargs,
add_generation_prompt=True,
tokenize=True,
padding=True,
padding_side="left",
return_tensors="pt",
return_dict=True,
**self.chat_template_kwargs,
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
generate_inputs = self.processing_class(
text=prompts, padding=True, padding_side="left", return_tensors="pt"
)
generate_inputs = super()._prepare_inputs(generate_inputs)

with (
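
With `truncation`/`max_length` removed from both the chat-template and plain-text tokenization paths (and `truncate_prompt_tokens` dropped from the vLLM sampling params), over-long prompts are no longer clipped by the trainer. For conversational datasets the up-front filtering therefore has to measure the templated prompt rather than the raw messages; a rough sketch under the same placeholder assumptions as above:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")  # placeholder conversational dataset

MAX_PROMPT_TOKENS = 512


def templated_prompt_fits(example):
    # Render the chat template with the generation prompt appended, as the
    # trainer does before generation, then count the resulting tokens.
    input_ids = tokenizer.apply_chat_template(
        example["prompt"],
        add_generation_prompt=True,
        tokenize=True,
    )
    return len(input_ids) <= MAX_PROMPT_TOKENS


dataset = dataset.filter(templated_prompt_fits)
```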