diff --git a/tests/experimental/test_kto_trainer.py b/tests/experimental/test_kto_trainer.py
index 7108de87c21..0f6ed1945ce 100644
--- a/tests/experimental/test_kto_trainer.py
+++ b/tests/experimental/test_kto_trainer.py
@@ -155,9 +155,7 @@ def test_tokenize_and_process_tokens(self):
             "prefix": "",
             "tokenizer": trainer.processing_class,
             "max_length": trainer.max_length,
-            "truncation_mode": trainer.truncation_mode,
             "label_pad_token_id": trainer.label_pad_token_id,
-            "max_prompt_length": trainer.max_prompt_length,
         }
         processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2)
         assert processed_dataset["prompt"][:] == train_dataset["prompt"][:]
diff --git a/trl/experimental/kto/kto_config.py b/trl/experimental/kto/kto_config.py
index b67b6416a1d..e23f27979a7 100644
--- a/trl/experimental/kto/kto_config.py
+++ b/trl/experimental/kto/kto_config.py
@@ -36,8 +36,6 @@ class KTOConfig(TrainingArguments):
         max_length (`int` or `None`, *optional*, defaults to `1024`):
             Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
             to use the default data collator.
-        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
-            Maximum length of the prompt. This argument is required if you want to use the default data collator.
         beta (`float`, *optional*, defaults to `0.1`):
             Parameter controlling the deviation from the reference model. Higher β means less deviation from the
             reference model.
@@ -56,9 +54,6 @@ class KTOConfig(TrainingArguments):
             Label pad token id. This argument is required if you want to use the default data collator.
         padding_value (`int`, *optional*):
             Padding value to use. If `None`, the padding value of the tokenizer is used.
-        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
-            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
-            This argument is required if you want to use the default data collator.
         generate_during_eval (`bool`, *optional*, defaults to `False`):
             If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
             during evaluation.
@@ -134,13 +129,6 @@ class KTOConfig(TrainingArguments):
         default=1024,
         metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
     )
-    max_prompt_length: int | None = field(
-        default=512,
-        metadata={
-            "help": "Maximum length of the prompt. This argument is required if you want to use the default data "
-            "collator."
-        },
-    )
     beta: float = field(
         default=0.1,
         metadata={
@@ -179,13 +167,6 @@ class KTOConfig(TrainingArguments):
         default=None,
         metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
     )
-    truncation_mode: str = field(
-        default="keep_end",
-        metadata={
-            "help": "Truncation mode to use when the prompt is too long.",
-            "choices": ["keep_end", "keep_start"],
-        },
-    )
     generate_during_eval: bool = field(
         default=False,
         metadata={
diff --git a/trl/experimental/kto/kto_trainer.py b/trl/experimental/kto/kto_trainer.py
index 6a9b7f36aa0..14677ac0eee 100644
--- a/trl/experimental/kto/kto_trainer.py
+++ b/trl/experimental/kto/kto_trainer.py
@@ -156,8 +156,7 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
     """Process tokens of a KTO specific dataset.

     At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
-    completion responses is/are too long. First we truncate the prompt; if we're still too long, we truncate the
-    completion.
+    completion responses is/are too long. We truncate from the end (completion) to fit within max_length.

     We also create the labels for the completion responses, which are of length equal to the sum of the length of
     the prompt and the completion response, with label_pad_token_id for the prompt tokens.
@@ -199,20 +198,13 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
         if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]:
             max_length -= 1

-        # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt
-        if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
-            for k in ["prompt_input_ids", "prompt_attention_mask"]:
-                if kwargs["truncation_mode"] == "keep_start":
-                    all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]]
-                elif kwargs["truncation_mode"] == "keep_end":
-                    all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :]
-                else:
-                    raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}")
-
-        # if that's still too long, truncate the response
-        if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
+        # if combined sequence is too long, truncate the completion (answer) from the end
+        prompt_length = len(all_tokens["prompt_input_ids"])
+        completion_length = len(all_tokens["answer_input_ids"])
+        if prompt_length + completion_length > max_length:
+            max_completion_length = max_length - prompt_length
             for k in ["answer_input_ids", "answer_attention_mask"]:
-                all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]]
+                all_tokens[k] = all_tokens[k][:max_completion_length]

         # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens
         batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"]
@@ -471,15 +463,6 @@ def make_inputs_require_grad(module, input, output):
         if args.max_length is not None:
             max_length = args.max_length

-        if args.max_prompt_length is None:
-            logger.warning(
-                "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
-                " it will be set to `128` by default, but you should do it yourself in the future.",
-            )
-            max_prompt_length = 128
-        if args.max_prompt_length is not None:
-            max_prompt_length = args.max_prompt_length
-
         if data_collator is None:
             data_collator = DPODataCollatorWithPadding(
                 pad_token_id=processing_class.pad_token_id,
@@ -509,8 +492,6 @@ def make_inputs_require_grad(module, input, output):
         self.generate_during_eval = args.generate_during_eval
         self.label_pad_token_id = args.label_pad_token_id
         self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
-        self.max_prompt_length = max_prompt_length
-        self.truncation_mode = args.truncation_mode
         self.processing_class = processing_class
         self.precompute_ref_log_probs = args.precompute_ref_log_probs

@@ -595,9 +576,7 @@ def make_inputs_require_grad(module, input, output):
             "prefix": "",
             "tokenizer": self.processing_class,
             "max_length": self.max_length,
-            "truncation_mode": self.truncation_mode,
             "label_pad_token_id": self.label_pad_token_id,
-            "max_prompt_length": self.max_prompt_length,
         }
         train_dataset = train_dataset.map(
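For reference, a minimal standalone sketch of the truncation rule this patch adopts in `_process_tokens`: keep the prompt intact and clip the completion from the end so the pair fits in `max_length`. The function name and token lists below are illustrative, not part of the trainer's API.

```python
# Illustrative sketch of the new truncation rule: the prompt is never
# shortened; only the completion is clipped from the end so that
# prompt + completion <= max_length. Like the diff, this assumes the
# prompt alone already fits within max_length.
def truncate_completion(
    prompt_ids: list[int], answer_ids: list[int], max_length: int
) -> tuple[list[int], list[int]]:
    if len(prompt_ids) + len(answer_ids) > max_length:
        max_completion_length = max_length - len(prompt_ids)
        answer_ids = answer_ids[:max_completion_length]
    return prompt_ids, answer_ids


prompt, answer = list(range(6)), list(range(10))
new_prompt, new_answer = truncate_completion(prompt, answer, max_length=10)
assert new_prompt == prompt        # prompt untouched
assert new_answer == answer[:4]    # completion clipped from the end
```

Unlike the removed `truncation_mode`/`max_prompt_length` path, a prompt longer than `max_length` is no longer shortened here, so overly long prompts must be handled upstream.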