tests/experimental/test_kto_trainer.py (0 additions, 2 deletions)

@@ -155,9 +155,7 @@ def test_tokenize_and_process_tokens(self):
             "prefix": "",
             "tokenizer": trainer.processing_class,
             "max_length": trainer.max_length,
-            "truncation_mode": trainer.truncation_mode,
             "label_pad_token_id": trainer.label_pad_token_id,
-            "max_prompt_length": trainer.max_prompt_length,
         }
         processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2)
         assert processed_dataset["prompt"][:] == train_dataset["prompt"][:]
trl/experimental/kto/kto_config.py (0 additions, 19 deletions)

@@ -36,8 +36,6 @@ class KTOConfig(TrainingArguments):
         max_length (`int` or `None`, *optional*, defaults to `1024`):
             Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
             to use the default data collator.
-        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
-            Maximum length of the prompt. This argument is required if you want to use the default data collator.
         beta (`float`, *optional*, defaults to `0.1`):
             Parameter controlling the deviation from the reference model. Higher β means less deviation from the
             reference model.
@@ -56,9 +54,6 @@ class KTOConfig(TrainingArguments):
             Label pad token id. This argument is required if you want to use the default data collator.
         padding_value (`int`, *optional*):
             Padding value to use. If `None`, the padding value of the tokenizer is used.
-        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
-            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
-            This argument is required if you want to use the default data collator.
         generate_during_eval (`bool`, *optional*, defaults to `False`):
             If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
             during evaluation.
@@ -134,13 +129,6 @@ class KTOConfig(TrainingArguments):
         default=1024,
         metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
     )
-    max_prompt_length: int | None = field(
-        default=512,
-        metadata={
-            "help": "Maximum length of the prompt. This argument is required if you want to use the default data "
-            "collator."
-        },
-    )
     beta: float = field(
         default=0.1,
         metadata={
@@ -179,13 +167,6 @@ class KTOConfig(TrainingArguments):
         default=None,
         metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
     )
-    truncation_mode: str = field(
-        default="keep_end",
-        metadata={
-            "help": "Truncation mode to use when the prompt is too long.",
-            "choices": ["keep_end", "keep_start"],
-        },
-    )
     generate_during_eval: bool = field(
         default=False,
         metadata={
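For reference, a minimal usage sketch of the slimmed-down config after this change. The import path mirrors the file location above (assuming the module is importable as laid out), and the output directory is a placeholder:

from trl.experimental.kto.kto_config import KTOConfig

# `max_length` is now the only sequence-length knob: it bounds the tokenized
# prompt + completion jointly, and the completion is clipped to fit.
# `max_prompt_length` and `truncation_mode` no longer exist on this config.
config = KTOConfig(
    output_dir="./kto-output",  # hypothetical path
    max_length=1024,
    beta=0.1,
)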
trl/experimental/kto/kto_trainer.py (7 additions, 28 deletions)

@@ -156,8 +156,7 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
     """Process tokens of a KTO specific dataset.

     At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
-    completion responses is/are too long. First we truncate the prompt; if we're still too long, we truncate the
-    completion.
+    completion responses is/are too long. We truncate from the end (completion) to fit within max_length.

     We also create the labels for the completion responses, which are of length equal to the sum of the length of the
     prompt and the completion response, with label_pad_token_id for the prompt tokens.
@@ -199,20 +198,13 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
         if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]:
             max_length -= 1

-        # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt
-        if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
-            for k in ["prompt_input_ids", "prompt_attention_mask"]:
-                if kwargs["truncation_mode"] == "keep_start":
-                    all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]]
-                elif kwargs["truncation_mode"] == "keep_end":
-                    all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :]
-                else:
-                    raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}")
-
-        # if that's still too long, truncate the response
-        if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
+        # if combined sequence is too long, truncate the completion (answer) from the end
+        prompt_length = len(all_tokens["prompt_input_ids"])
+        completion_length = len(all_tokens["answer_input_ids"])
+        if prompt_length + completion_length > max_length:
+            max_completion_length = max_length - prompt_length
             for k in ["answer_input_ids", "answer_attention_mask"]:
-                all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]]
+                all_tokens[k] = all_tokens[k][:max_completion_length]

         # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens
         batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"]
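To make the new rule concrete, here is a self-contained sketch of the same logic outside the trainer (the function name and plain-list token representation are illustrative only): the prompt is kept intact and the completion is clipped from its end until the pair fits in max_length.

def truncate_pair(prompt_ids: list[int], answer_ids: list[int], max_length: int) -> tuple[list[int], list[int]]:
    # Mirror of the diff above: truncate only when the pair overflows,
    # and only by dropping completion tokens from the end.
    if len(prompt_ids) + len(answer_ids) > max_length:
        answer_ids = answer_ids[: max_length - len(prompt_ids)]
    return prompt_ids, answer_ids

# A 900-token prompt with a 300-token completion and max_length=1024
# keeps the whole prompt and the first 124 completion tokens.
prompt, answer = truncate_pair(list(range(900)), list(range(300)), 1024)
assert len(prompt) == 900 and len(answer) == 124

One consequence worth noting: since the prompt is never shortened anymore, a prompt that on its own exceeds max_length is no longer cut down here, so keeping prompts under max_length becomes the caller's responsibility.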
@@ -471,15 +463,6 @@ def make_inputs_require_grad(module, input, output):
         if args.max_length is not None:
             max_length = args.max_length

-        if args.max_prompt_length is None:
-            logger.warning(
-                "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
-                " it will be set to `128` by default, but you should do it yourself in the future.",
-            )
-            max_prompt_length = 128
-        if args.max_prompt_length is not None:
-            max_prompt_length = args.max_prompt_length
-
         if data_collator is None:
             data_collator = DPODataCollatorWithPadding(
                 pad_token_id=processing_class.pad_token_id,
@@ -509,8 +492,6 @@ def make_inputs_require_grad(module, input, output):
         self.generate_during_eval = args.generate_during_eval
         self.label_pad_token_id = args.label_pad_token_id
         self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
-        self.max_prompt_length = max_prompt_length
-        self.truncation_mode = args.truncation_mode
         self.processing_class = processing_class
         self.precompute_ref_log_probs = args.precompute_ref_log_probs

@@ -595,9 +576,7 @@ def make_inputs_require_grad(module, input, output):
             "prefix": "",
             "tokenizer": self.processing_class,
             "max_length": self.max_length,
-            "truncation_mode": self.truncation_mode,
             "label_pad_token_id": self.label_pad_token_id,
-            "max_prompt_length": self.max_prompt_length,
         }

         train_dataset = train_dataset.map(
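Finally, a hedged end-to-end sketch of driving the trainer after this change. The model name and dataset are placeholder choices, and it assumes the experimental KTOTrainer keeps the standard TRL-style constructor (model, args, train_dataset, ...):

from datasets import load_dataset
from trl.experimental.kto.kto_config import KTOConfig
from trl.experimental.kto.kto_trainer import KTOTrainer

# No max_prompt_length or truncation_mode anywhere: max_length alone
# bounds each tokenized prompt + completion pair.
trainer = KTOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model
    args=KTOConfig(output_dir="./kto-output", max_length=1024),
    train_dataset=load_dataset("trl-lib/kto-mix-14k", split="train"),  # placeholder dataset
)
trainer.train()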