From 5cb6fe05b11ed370bdce600ab01f914c13d73c8f Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 4 Feb 2026 10:42:02 +0100
Subject: [PATCH 1/4] Remove max_prompt_length from PRMConfig

---
 trl/experimental/prm/prm_config.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/trl/experimental/prm/prm_config.py b/trl/experimental/prm/prm_config.py
index 3bbd63cfc30..b76fdc9e3ae 100644
--- a/trl/experimental/prm/prm_config.py
+++ b/trl/experimental/prm/prm_config.py
@@ -35,8 +35,6 @@ class PRMConfig(TrainingArguments):
     Parameters:
         max_length (`int` or `None`, *optional*, defaults to `1024`):
             Maximum length of the sequences (prompt + completion) used for truncation.
-        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
-            Maximum length of the prompt used for truncation.
         max_completion_length (`int`, *optional*):
             Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
         disable_dropout (`bool`, *optional*, defaults to `True`):
@@ -80,10 +78,6 @@ class PRMConfig(TrainingArguments):
         default=1024,
         metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."},
     )
-    max_prompt_length: int | None = field(
-        default=512,
-        metadata={"help": "Maximum length of the prompt used for truncation."},
-    )
     max_completion_length: int | None = field(
         default=None,
         metadata={

From 647b33d9e65571a2c65ab785a4fb23261ace5068 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 4 Feb 2026 10:43:26 +0100
Subject: [PATCH 2/4] Remove max_prompt_length from PRMTrainer

---
 trl/experimental/prm/prm_trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/trl/experimental/prm/prm_trainer.py b/trl/experimental/prm/prm_trainer.py
index aa9956441e4..683d61d135e 100644
--- a/trl/experimental/prm/prm_trainer.py
+++ b/trl/experimental/prm/prm_trainer.py
@@ -190,7 +190,6 @@ def __init__(
                 "tokenizer": processing_class,
                 "step_separator": args.step_separator,
                 "max_length": args.max_length,
-                "max_prompt_length": args.max_prompt_length,
                 "max_completion_length": args.max_completion_length,
                 "train_on_last_step_only": args.train_on_last_step_only,
             }

From 53a790fe57a86803ed8c1cef46ef822c547ab3fa Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 4 Feb 2026 10:44:14 +0100
Subject: [PATCH 3/4] Remove max_prompt_length from PRMTrainer.tokenize_row

---
 trl/experimental/prm/prm_trainer.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/trl/experimental/prm/prm_trainer.py b/trl/experimental/prm/prm_trainer.py
index 683d61d135e..79bf4bb904a 100644
--- a/trl/experimental/prm/prm_trainer.py
+++ b/trl/experimental/prm/prm_trainer.py
@@ -248,7 +248,6 @@ def tokenize_row(
         tokenizer,
         step_separator,
         max_length,
-        max_prompt_length,
         max_completion_length,
         train_on_last_step_only,
         is_eval,
@@ -265,8 +264,6 @@ def tokenize_row(
                 Separator between steps in the completion.
             max_length (`int` or `None`):
                 Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
-            max_prompt_length (`int` or `None`):
-                Maximum length of the prompt. If `None`, the prompt is not truncated.
             max_completion_length (`int` or `None`):
                 Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
             train_on_last_step_only (`bool`):
@@ -323,9 +320,7 @@ def tokenize_row(
         if tokenizer.bos_token_id is not None:
             prompt_ids = [tokenizer.bos_token_id] + prompt_ids
 
-        # Truncate prompt and completion sequences
-        if max_prompt_length is not None:
-            prompt_ids = prompt_ids[-max_prompt_length:]
+        # Truncate completion sequences
         if max_completion_length is not None:
             completion_ids = completion_ids[:max_completion_length]
             labels = labels[:max_completion_length]

From 6d01747fca7edf99e1faf21aeb2bcfb7c8698ddc Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 4 Feb 2026 10:44:33 +0100
Subject: [PATCH 4/4] Update tests

---
 tests/experimental/test_prm_trainer.py | 30 --------------------------
 1 file changed, 30 deletions(-)

diff --git a/tests/experimental/test_prm_trainer.py b/tests/experimental/test_prm_trainer.py
index 3bd8f553afb..be36adb39a4 100644
--- a/tests/experimental/test_prm_trainer.py
+++ b/tests/experimental/test_prm_trainer.py
@@ -136,7 +136,6 @@ def test_tokenize_row_no_truncation(self):
             tokenizer=self.tokenizer,
             step_separator="\n",
             max_length=None,
-            max_prompt_length=None,
             max_completion_length=None,
             train_on_last_step_only=False,
             is_eval=False,
@@ -160,7 +159,6 @@ def test_tokenize_row_train_on_last_step_only(self):
             tokenizer=self.tokenizer,
             step_separator="\n",
             max_length=None,
-            max_prompt_length=None,
             max_completion_length=None,
             train_on_last_step_only=True,
             is_eval=False,
@@ -171,31 +169,6 @@ def test_tokenize_row_train_on_last_step_only(self):
             "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0],
         }
 
-    def test_tokenize_row_prompt_truncation(self):
-        # Define the input features
-        features = {
-            "prompt": "Which number is larger, 9.8 or 9.11?",
-            "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
-            "labels": [True, False],
-        }
-
-        # Call the method with truncation on the completion
-        result = PRMTrainer.tokenize_row(
-            features=features,
-            tokenizer=self.tokenizer,
-            step_separator="\n",
-            max_length=None,
-            max_prompt_length=3,
-            max_completion_length=None,
-            train_on_last_step_only=False,
-            is_eval=False,
-        )
-
-        assert result == {
-            "input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030],
-            "labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0],
-        }
-
     def test_tokenize_row_completion_truncation(self):
         # Define the input features
         features = {
@@ -210,7 +183,6 @@ def test_tokenize_row_completion_truncation(self):
             tokenizer=self.tokenizer,
             step_separator="\n",
             max_length=None,
-            max_prompt_length=None,
             max_completion_length=6,
             train_on_last_step_only=False,
             is_eval=False,
@@ -235,7 +207,6 @@ def test_tokenize_row_prompt_completion_truncation(self):
             tokenizer=self.tokenizer,
             step_separator="\n",
             max_length=9,
-            max_prompt_length=None,
             max_completion_length=None,
             train_on_last_step_only=False,
             is_eval=False,
@@ -260,7 +231,6 @@ def test_tokenize_row_multi_token_separator(self):
             tokenizer=self.tokenizer,
             step_separator="\n\n",
             max_length=None,
-            max_prompt_length=None,
             max_completion_length=None,
             train_on_last_step_only=False,
             is_eval=False,