2 changes: 0 additions & 2 deletions examples/scripts/bco.py
@@ -42,7 +42,6 @@
--output_dir bco-aligned-model \
--logging_first_step \
--max_length 2048 \
--max_prompt_length 1536 \
--max_completion_length 1024 \
--no_remove_unused_columns \
--warmup_ratio 0.1
@@ -63,7 +62,6 @@
--logging_first_step \
--warmup_ratio 0.1 \
--max_length 2048 \
--max_prompt_length 1536 \
--max_completion_length 1024 \
--no_remove_unused_columns \
--warmup_ratio 0.1 \
1 change: 0 additions & 1 deletion tests/experimental/test_bco_trainer.py
@@ -211,7 +211,6 @@ def test_tokenize_and_process_tokens(self):
"tokenizer": trainer.processing_class,
"max_length": trainer.max_length,
"truncation_mode": trainer.truncation_mode,
"max_prompt_length": trainer.max_prompt_length,
}
processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs)
assert processed_dataset["prompt"][:] == dataset["prompt"][:]
9 changes: 0 additions & 9 deletions trl/experimental/bco/bco_config.py
@@ -37,8 +37,6 @@ class BCOConfig(TrainingArguments):
max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`int`, *optional*):
Maximum length of the completion. This argument is required if you want to use the default data collator
and your model is an encoder-decoder.
@@ -114,13 +112,6 @@ class BCOConfig(TrainingArguments):
"This argument is required if you want to use the default data collator."
},
)
max_prompt_length: int | None = field(
default=512,
metadata={
"help": "Maximum length of the prompt. "
"This argument is required if you want to use the default data collator."
},
)
max_completion_length: int | None = field(
default=None,
metadata={
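With `max_prompt_length` dropped from `BCOConfig`, only `max_length` (and `max_completion_length` for encoder-decoder models) bound the tokenized sequences. A minimal sketch of the resulting configuration, assuming `BCOConfig` is importable from `trl.experimental.bco` as the file layout above suggests:

```python
from trl.experimental.bco import BCOConfig  # import path assumed from the diff's file layout

# After this change there is no `max_prompt_length`; the prompt is only bounded
# indirectly by `max_length`, the shared prompt + completion budget.
config = BCOConfig(
    output_dir="bco-aligned-model",
    max_length=2048,              # total prompt + completion length
    max_completion_length=1024,   # only needed for encoder-decoder models
    remove_unused_columns=False,  # mirrors --no_remove_unused_columns in the example script
)
```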
30 changes: 3 additions & 27 deletions trl/experimental/bco/bco_trainer.py
@@ -205,20 +205,10 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
if eos_token_id != all_tokens["answer_input_ids"][-1]:
max_length -= 1

# if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt
if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
for k in ["prompt_input_ids", "prompt_attention_mask"]:
if kwargs["truncation_mode"] == "keep_start":
all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]]
elif kwargs["truncation_mode"] == "keep_end":
all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :]
else:
raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}")

# if that's still too long, truncate the response
# if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the response
if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
for k in ["answer_input_ids", "answer_attention_mask"]:
all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]]
all_tokens[k] = all_tokens[k][: max_length - len(all_tokens["prompt_input_ids"])]

# all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens
batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"]
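With the prompt-truncation branch removed, the rule is simpler: keep the prompt as tokenized and clip only the answer so the pair fits within `max_length`. A minimal sketch of that rule (not the trainer's full preprocessing):

```python
def truncate_pair(prompt_ids: list[int], answer_ids: list[int], max_length: int):
    """Clip only the answer so that len(prompt) + len(answer) <= max_length.

    Mirrors the updated branch above: a prompt that is itself longer than
    `max_length` is no longer cut down at this step.
    """
    if len(prompt_ids) + len(answer_ids) > max_length:
        answer_ids = answer_ids[: max_length - len(prompt_ids)]
    return prompt_ids, answer_ids


# e.g. a 1900-token prompt and a 400-token answer with max_length=2048
# keeps the whole prompt and the first 148 answer tokens.
```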
@@ -262,9 +252,7 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
completion_tokens = kwargs["tokenizer"](
completion, truncation=True, max_length=kwargs["max_completion_length"], add_special_tokens=True
)
prompt_tokens = kwargs["tokenizer"](
prompt, truncation=True, max_length=kwargs["max_prompt_length"], add_special_tokens=True
)
prompt_tokens = kwargs["tokenizer"](prompt, add_special_tokens=True)

batch[f"{kwargs['prefix']}prompt_input_ids"] = prompt_tokens["input_ids"]
batch[f"{kwargs['prefix']}prompt_attention_mask"] = prompt_tokens["attention_mask"]
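For encoder-decoder models, only the completion is length-capped at tokenization time now; the prompt is tokenized without `truncation`/`max_length`. A minimal sketch, assuming a standard Hugging Face tokenizer (the checkpoint name is only illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")  # illustrative encoder-decoder checkpoint

completion_tokens = tokenizer(
    "a target completion", truncation=True, max_length=1024, add_special_tokens=True
)
# Prompt tokenization no longer takes truncation/max_length after this change:
prompt_tokens = tokenizer("a prompt", add_special_tokens=True)
```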
@@ -501,15 +489,6 @@ def make_inputs_require_grad(module, input, output):
if args.max_length is not None:
max_length = args.max_length

if args.max_prompt_length is None:
logger.warning(
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
"It will be set to `128` by default, but you should do it yourself in the future.",
)
max_prompt_length = 128
if args.max_prompt_length is not None:
max_prompt_length = args.max_prompt_length

max_completion_length = None
if args.max_completion_length is None and self.is_encoder_decoder:
logger.warning(
@@ -546,7 +525,6 @@ def make_inputs_require_grad(module, input, output):

self.max_length = max_length
self.generate_during_eval = args.generate_during_eval
self.max_prompt_length = max_prompt_length
self.truncation_mode = args.truncation_mode
self.max_completion_length = max_completion_length
self.precompute_ref_log_probs = args.precompute_ref_log_probs
@@ -619,7 +597,6 @@ def make_inputs_require_grad(module, input, output):
"tokenizer": processing_class,
"max_length": self.max_length,
"truncation_mode": self.truncation_mode,
"max_prompt_length": self.max_prompt_length,
"max_completion_length": self.max_completion_length,
}
train_dataset = train_dataset.map(
@@ -646,7 +623,6 @@
"tokenizer": processing_class,
"max_length": self.max_length,
"truncation_mode": self.truncation_mode,
"max_prompt_length": self.max_prompt_length,
"max_completion_length": self.max_completion_length,
}
eval_dataset = eval_dataset.map(