9 changes: 0 additions & 9 deletions trl/experimental/orpo/orpo_config.py
@@ -37,8 +37,6 @@ class ORPOConfig(TrainingArguments):
         max_length (`int` or `None`, *optional*, defaults to `1024`):
             Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
             to use the default data collator.
-        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
-            Maximum length of the prompt. This argument is required if you want to use the default data collator.
         max_completion_length (`int`, *optional*):
             Maximum length of the completion. This argument is required if you want to use the default data collator
             and your model is an encoder-decoder.
@@ -108,13 +106,6 @@ class ORPOConfig(TrainingArguments):
         default=1024,
         metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
     )
-    max_prompt_length: int | None = field(
-        default=512,
-        metadata={
-            "help": "Maximum length of the prompt. This argument is required if you want to use the default data "
-            "collator and your model is an encoder-decoder."
-        },
-    )
     max_completion_length: int | None = field(
         default=None,
         metadata={
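For readers checking what this removal means in practice, here is a minimal usage sketch, assuming the module is importable at the path shown in this diff; output_dir and the numeric values are illustrative, not taken from the PR.

    # Minimal sketch: after this change, ORPOConfig no longer has a max_prompt_length field.
    from trl.experimental.orpo.orpo_config import ORPOConfig

    config = ORPOConfig(
        output_dir="orpo-output",       # standard TrainingArguments field
        max_length=1024,                # total prompt + completion token budget
        max_completion_length=256,      # only needed with the default collator on encoder-decoder models
    )

    # Since ORPOConfig is a dataclass, passing the removed field would now raise a TypeError:
    # ORPOConfig(output_dir="orpo-output", max_prompt_length=512)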
29 changes: 3 additions & 26 deletions trl/experimental/orpo/orpo_trainer.py
@@ -261,14 +261,6 @@ def make_inputs_require_grad(module, input, output):
             max_length = 512
         else:
             max_length = args.max_length
-        if args.max_prompt_length is None:
-            logger.warning(
-                "`max_prompt_length` is not set in the ORPOConfig's init"
-                " it will default to `128` by default, but you should do it yourself in the future.",
-            )
-            max_prompt_length = 128
-        else:
-            max_prompt_length = args.max_prompt_length

         if args.max_completion_length is None and self.is_encoder_decoder:
             logger.warning(
@@ -304,7 +296,6 @@ def make_inputs_require_grad(module, input, output):
         self.max_length = max_length
         self.generate_during_eval = args.generate_during_eval
         self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
-        self.max_prompt_length = max_prompt_length
         self.truncation_mode = args.truncation_mode
         self.processing_class = processing_class

@@ -492,23 +483,11 @@ def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None

             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

-            # if combined sequence is too long, truncate the prompt
-            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
-                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
-                    if self.truncation_mode == "keep_start":
-                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
-                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
-                    elif self.truncation_mode == "keep_end":
-                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
-                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
-                    else:
-                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
-
-            # if that's still too long, truncate the response
+            # if combined sequence is too long, truncate the response
             for answer_tokens in [chosen_tokens, rejected_tokens]:
                 if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                     for k in ["input_ids", "attention_mask"]:
-                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
+                        answer_tokens[k] = answer_tokens[k][: self.max_length - longer_response_length]

             # Create labels
             chosen_sequence_tokens = {
@@ -543,9 +522,7 @@ def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None
             rejected_tokens = self.processing_class(
                 rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
             )
-            prompt_tokens = self.processing_class(
-                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
-            )
+            prompt_tokens = self.processing_class(prompt, add_special_tokens=True)

             batch["chosen_labels"] = chosen_tokens["input_ids"]
             batch["rejected_labels"] = rejected_tokens["input_ids"]