diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 267b64b6e3b..9511f27d5e0 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1259,12 +1259,15 @@ def _generate_single_turn(self, prompts: list): processor_kwargs = { "max_length": self.max_prompt_length, "truncation": True, - "add_generation_prompt": True, "add_special_tokens": False, } if is_conversational({"prompt": prompts[0]}): processor_outputs = self.processing_class.apply_chat_template( - conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True + conversation=prompts, + **processor_kwargs, + add_generation_prompt=True, + tokenize=True, + return_dict=True, ) else: processor_outputs = self.processing_class(text=prompts, **processor_kwargs) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 3ed56624e83..b9fdb3fd5b1 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1084,13 +1084,13 @@ def _generate_single_turn(self, prompts: list): processor_kwargs = { "max_length": self.max_prompt_length, "truncation": True, - "add_generation_prompt": True, "add_special_tokens": False, } if is_conversational({"prompt": prompts[0]}): processor_outputs = self.processing_class.apply_chat_template( conversation=prompts, **processor_kwargs, + add_generation_prompt=True, tokenize=True, return_dict=True, ) @@ -1133,7 +1133,7 @@ def _generate_single_turn(self, prompts: list): generate_inputs = self.processing_class.apply_chat_template( conversation=prompts, **processor_kwargs, - add_generation_kwargs=True, + add_generation_prompt=True, tokenize=True, return_dict=True, )