diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py
index 2f81ff4363..6ad474914b 100644
--- a/examples/scripts/bco.py
+++ b/examples/scripts/bco.py
@@ -42,7 +42,6 @@
     --output_dir bco-aligned-model \
     --logging_first_step \
     --max_length 2048 \
-    --max_prompt_length 1536 \
     --max_completion_length 1024 \
     --no_remove_unused_columns \
     --warmup_ratio 0.1
@@ -63,7 +62,6 @@
     --logging_first_step \
     --warmup_ratio 0.1 \
     --max_length 2048 \
-    --max_prompt_length 1536 \
     --max_completion_length 1024 \
     --no_remove_unused_columns \
     --warmup_ratio 0.1 \
diff --git a/tests/experimental/test_bco_trainer.py b/tests/experimental/test_bco_trainer.py
index 30880e0c57..34f42967c3 100644
--- a/tests/experimental/test_bco_trainer.py
+++ b/tests/experimental/test_bco_trainer.py
@@ -211,7 +211,6 @@ def test_tokenize_and_process_tokens(self):
             "tokenizer": trainer.processing_class,
             "max_length": trainer.max_length,
             "truncation_mode": trainer.truncation_mode,
-            "max_prompt_length": trainer.max_prompt_length,
         }
         processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs)
         assert processed_dataset["prompt"][:] == dataset["prompt"][:]
diff --git a/trl/experimental/bco/bco_config.py b/trl/experimental/bco/bco_config.py
index 5e4783aa75..8a14a816e3 100644
--- a/trl/experimental/bco/bco_config.py
+++ b/trl/experimental/bco/bco_config.py
@@ -37,8 +37,6 @@ class BCOConfig(TrainingArguments):
         max_length (`int` or `None`, *optional*, defaults to `1024`):
             Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
             to use the default data collator.
-        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
-            Maximum length of the prompt. This argument is required if you want to use the default data collator.
         max_completion_length (`int`, *optional*):
             Maximum length of the completion. This argument is required if you want to use the default data collator
             and your model is an encoder-decoder.
@@ -114,13 +112,6 @@ class BCOConfig(TrainingArguments):
             "This argument is required if you want to use the default data collator."
         },
     )
-    max_prompt_length: int | None = field(
-        default=512,
-        metadata={
-            "help": "Maximum length of the prompt. "
-            "This argument is required if you want to use the default data collator."
-        },
-    )
     max_completion_length: int | None = field(
         default=None,
         metadata={
diff --git a/trl/experimental/bco/bco_trainer.py b/trl/experimental/bco/bco_trainer.py
index d769f49ec4..e124404ecd 100644
--- a/trl/experimental/bco/bco_trainer.py
+++ b/trl/experimental/bco/bco_trainer.py
@@ -205,20 +205,10 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
         if eos_token_id != all_tokens["answer_input_ids"][-1]:
             max_length -= 1

-        # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt
-        if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
-            for k in ["prompt_input_ids", "prompt_attention_mask"]:
-                if kwargs["truncation_mode"] == "keep_start":
-                    all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]]
-                elif kwargs["truncation_mode"] == "keep_end":
-                    all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :]
-                else:
-                    raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}")
-
-        # if that's still too long, truncate the response
+        # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the response
         if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
             for k in ["answer_input_ids", "answer_attention_mask"]:
-                all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]]
+                all_tokens[k] = all_tokens[k][: max_length - len(all_tokens["prompt_input_ids"])]

         # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens
         batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"]
@@ -262,9 +252,7 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
         completion_tokens = kwargs["tokenizer"](
             completion, truncation=True, max_length=kwargs["max_completion_length"], add_special_tokens=True
         )
-        prompt_tokens = kwargs["tokenizer"](
-            prompt, truncation=True, max_length=kwargs["max_prompt_length"], add_special_tokens=True
-        )
+        prompt_tokens = kwargs["tokenizer"](prompt, add_special_tokens=True)

         batch[f"{kwargs['prefix']}prompt_input_ids"] = prompt_tokens["input_ids"]
         batch[f"{kwargs['prefix']}prompt_attention_mask"] = prompt_tokens["attention_mask"]
@@ -501,15 +489,6 @@ def make_inputs_require_grad(module, input, output):
             if args.max_length is not None:
                 max_length = args.max_length

-            if args.max_prompt_length is None:
-                logger.warning(
-                    "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
-                    "It will be set to `128` by default, but you should do it yourself in the future.",
-                )
-                max_prompt_length = 128
-            if args.max_prompt_length is not None:
-                max_prompt_length = args.max_prompt_length
-
             max_completion_length = None
             if args.max_completion_length is None and self.is_encoder_decoder:
                 logger.warning(
@@ -546,7 +525,6 @@ def make_inputs_require_grad(module, input, output):

         self.max_length = max_length
         self.generate_during_eval = args.generate_during_eval
-        self.max_prompt_length = max_prompt_length
         self.truncation_mode = args.truncation_mode
         self.max_completion_length = max_completion_length
         self.precompute_ref_log_probs = args.precompute_ref_log_probs
@@ -619,7 +597,6 @@ def make_inputs_require_grad(module, input, output):
                 "tokenizer": processing_class,
                 "max_length": self.max_length,
                 "truncation_mode": self.truncation_mode,
-                "max_prompt_length": self.max_prompt_length,
                 "max_completion_length": self.max_completion_length,
             }
             train_dataset = train_dataset.map(
@@ -646,7 +623,6 @@ def make_inputs_require_grad(module, input, output):
                 "tokenizer": processing_class,
                 "max_length": self.max_length,
                 "truncation_mode": self.truncation_mode,
-                "max_prompt_length": self.max_prompt_length,
                 "max_completion_length": self.max_completion_length,
             }
             eval_dataset = eval_dataset.map(