Merged
30 changes: 0 additions & 30 deletions tests/experimental/test_prm_trainer.py
@@ -136,7 +136,6 @@ def test_tokenize_row_no_truncation(self):
tokenizer=self.tokenizer,
step_separator="\n",
max_length=None,
- max_prompt_length=None,
max_completion_length=None,
train_on_last_step_only=False,
is_eval=False,
@@ -160,7 +159,6 @@ def test_tokenize_row_train_on_last_step_only(self):
tokenizer=self.tokenizer,
step_separator="\n",
max_length=None,
- max_prompt_length=None,
max_completion_length=None,
train_on_last_step_only=True,
is_eval=False,
@@ -171,31 +169,6 @@ def test_tokenize_row_train_on_last_step_only(self):
"labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0],
}

- def test_tokenize_row_prompt_truncation(self):
- # Define the input features
- features = {
- "prompt": "Which number is larger, 9.8 or 9.11?",
- "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
- "labels": [True, False],
- }
-
- # Call the method with truncation on the completion
- result = PRMTrainer.tokenize_row(
- features=features,
- tokenizer=self.tokenizer,
- step_separator="\n",
- max_length=None,
- max_prompt_length=3,
- max_completion_length=None,
- train_on_last_step_only=False,
- is_eval=False,
- )
-
- assert result == {
- "input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
- "labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0],
- }

def test_tokenize_row_completion_truncation(self):
# Define the input features
features = {
@@ -210,7 +183,6 @@ def test_tokenize_row_completion_truncation(self):
tokenizer=self.tokenizer,
step_separator="\n",
max_length=None,
- max_prompt_length=None,
max_completion_length=6,
train_on_last_step_only=False,
is_eval=False,
@@ -235,7 +207,6 @@ def test_tokenize_row_prompt_completion_truncation(self):
tokenizer=self.tokenizer,
step_separator="\n",
max_length=9,
- max_prompt_length=None,
max_completion_length=None,
train_on_last_step_only=False,
is_eval=False,
@@ -260,7 +231,6 @@ def test_tokenize_row_multi_token_separator(self):
tokenizer=self.tokenizer,
step_separator="\n\n",
max_length=None,
- max_prompt_length=None,
max_completion_length=None,
train_on_last_step_only=False,
is_eval=False,
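For reference, here is a minimal sketch (not part of this PR) of how `PRMTrainer.tokenize_row` is called once `max_prompt_length` is gone, mirroring the remaining tests above. The import path is assumed from the file layout in this diff, and the tokenizer checkpoint is a stand-in for the test suite's own fixture, so the exact token ids will differ.

```python
# Sketch only: calling tokenize_row after this change, with no max_prompt_length argument.
from transformers import AutoTokenizer

from trl.experimental.prm import PRMTrainer  # import path assumed from this diff

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # stand-in tokenizer

features = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
    "labels": [True, False],
}

result = PRMTrainer.tokenize_row(
    features=features,
    tokenizer=tokenizer,
    step_separator="\n",
    max_length=None,             # joint prompt + completion cap
    max_completion_length=None,  # completion-only cap
    train_on_last_step_only=False,
    is_eval=False,
)
print(result["input_ids"])
print(result["labels"])  # -100 on prompt tokens, 0/1 at each step separator
```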
6 changes: 0 additions & 6 deletions trl/experimental/prm/prm_config.py
@@ -35,8 +35,6 @@ class PRMConfig(TrainingArguments):
Parameters:
max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the sequences (prompt + completion) used for truncation.
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
- Maximum length of the prompt used for truncation.
max_completion_length (`int`, *optional*):
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
disable_dropout (`bool`, *optional*, defaults to `True`):
@@ -80,10 +78,6 @@ class PRMConfig(TrainingArguments):
default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."},
)
- max_prompt_length: int | None = field(
- default=512,
- metadata={"help": "Maximum length of the prompt used for truncation."},
- )
max_completion_length: int | None = field(
default=None,
metadata={
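As a usage note, the sketch below (illustrative values, not from this PR) shows how `PRMConfig` is set up after the removal: truncation is now controlled only by `max_length` and `max_completion_length`. The import path and output directory are assumptions.

```python
# Sketch: PRMConfig without max_prompt_length.
from trl.experimental.prm import PRMConfig  # import path assumed from this diff

config = PRMConfig(
    output_dir="Qwen2.5-0.5B-PRM",  # hypothetical output path
    max_length=1024,                # caps the joint prompt + completion sequence
    max_completion_length=None,     # optional completion-only cap
    # max_prompt_length=512,        # no longer accepted after this PR
)
```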
8 changes: 1 addition & 7 deletions trl/experimental/prm/prm_trainer.py
@@ -190,7 +190,6 @@ def __init__(
"tokenizer": processing_class,
"step_separator": args.step_separator,
"max_length": args.max_length,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
"train_on_last_step_only": args.train_on_last_step_only,
}
@@ -249,7 +248,6 @@ def tokenize_row(
tokenizer,
step_separator,
max_length,
- max_prompt_length,
max_completion_length,
train_on_last_step_only,
is_eval,
@@ -266,8 +264,6 @@
Separator between steps in the completion.
max_length (`int` or `None`):
Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
- max_prompt_length (`int` or `None`):
- Maximum length of the prompt. If `None`, the prompt is not truncated.
max_completion_length (`int` or `None`):
Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
train_on_last_step_only (`bool`):
@@ -324,9 +320,7 @@ def tokenize_row(
if tokenizer.bos_token_id is not None:
prompt_ids = [tokenizer.bos_token_id] + prompt_ids

- # Truncate prompt and completion sequences
- if max_prompt_length is not None:
- prompt_ids = prompt_ids[-max_prompt_length:]
+ # Truncate completion sequences
if max_completion_length is not None:
completion_ids = completion_ids[:max_completion_length]
labels = labels[:max_completion_length]
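To summarise the behavioural change, here is a standalone sketch (a hypothetical helper, not the trainer's own code) of the truncation that `tokenize_row` still performs: the completion can be capped with `max_completion_length` and the concatenated sequence with `max_length`, while the prompt-only cap is gone. The head-truncation for `max_length` is an assumption; only the completion truncation is shown verbatim in the diff above.

```python
# Hypothetical helper illustrating the remaining truncation paths.
def truncate_example(prompt_ids, completion_ids, labels,
                     max_length=None, max_completion_length=None):
    # Completion-only cap, as in the diff above
    if max_completion_length is not None:
        completion_ids = completion_ids[:max_completion_length]
        labels = labels[:max_completion_length]

    # Prompt tokens never carry labels (-100), matching the expected test outputs
    input_ids = prompt_ids + completion_ids
    labels = [-100] * len(prompt_ids) + labels

    # Joint cap on prompt + completion; cutting from the head is an assumption
    if max_length is not None:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
    return {"input_ids": input_ids, "labels": labels}


# Example: a 3-token prompt with a 4-token completion capped at 2 completion tokens
print(truncate_example([5, 6, 7], [8, 9, 10, 11], [-100, 1, -100, 0],
                       max_completion_length=2))
# {'input_ids': [5, 6, 7, 8, 9], 'labels': [-100, -100, -100, -100, 1]}
```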