Feat: Add support for APO-zero in KTOTrainer #1952

Merged · 17 commits · Sep 4, 2024
Changes from 7 commits
13 changes: 10 additions & 3 deletions examples/scripts/kto.py
@@ -59,7 +59,7 @@
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

-from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format
+from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_reformat_dpo_to_kto, setup_chat_format


 # Define and parse arguments.
@@ -97,10 +97,17 @@ class ScriptArguments:
 # Load the dataset
 dataset = load_dataset(script_args.dataset_name)

+# If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
+dataset = maybe_reformat_dpo_to_kto(dataset, num_proc=kto_args.dataset_num_proc)
+
 # Apply chat template
 def format_dataset(example):
-    example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
-    example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
+    if isinstance(example["completion"], str):
+        example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
+        example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
+    else:
+        example["prompt"] = tokenizer.apply_chat_template(example["completion"][:-1], tokenize=False)
+        example["completion"] = tokenizer.apply_chat_template([example["completion"][-1]], tokenize=False)
     return example

 # Compute that only on the main process for faster data processing.
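To make the reformatting concrete, here is a toy illustration of the row transformation that maybe_reformat_dpo_to_kto performs (the example rows are invented for illustration, not taken from the PR):

```python
# One DPO-formatted row pairs a prompt with a chosen and a rejected completion:
dpo_row = {"prompt": "2 + 2 =", "chosen": " 4", "rejected": " 5"}

# After reformatting, the pair becomes two unpaired KTO rows with boolean labels:
kto_rows = [
    {"prompt": "2 + 2 =", "completion": " 4", "label": True},   # desirable completion
    {"prompt": "2 + 2 =", "completion": " 5", "label": False},  # undesirable completion
]
```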
2 changes: 2 additions & 0 deletions trl/__init__.py
@@ -80,6 +80,7 @@
"MultitaskPromptTuningConfig",
"MultitaskPromptTuningInit",
],
"data_utils": ["maybe_reformat_dpo_to_kto"],
}

try:
@@ -160,6 +161,7 @@
     from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
     from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
     from .commands.cli_utils import init_zero_verbose, SFTScriptArguments, DPOScriptArguments, TrlParser
+    from .data_utils import maybe_reformat_dpo_to_kto

 try:
     if not is_diffusers_available():
2 changes: 1 addition & 1 deletion trl/commands/cli.py
@@ -21,7 +21,7 @@
 from rich.console import Console


-SUPPORTED_COMMANDS = ["sft", "dpo", "chat"]
+SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto"]


 def main():
56 changes: 56 additions & 0 deletions trl/data_utils.py
@@ -0,0 +1,56 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+
+from datasets import DatasetDict
+
+
+def _reformat_row_dpo_to_kto(row: dict):
+    # turn each paired row into two unpaired rows
+
+    chosen_row = {"prompt": row["prompt"], "completion": row["chosen"], "label": [True] * len(row["chosen"])}
+    rejected_row = {
+        "prompt": row["prompt"],
+        "completion": row["rejected"],
+        "label": [False] * len(row["chosen"]),
+    }
+    new_rows = {k: chosen_row[k] + rejected_row[k] for k in chosen_row.keys()}
+    return new_rows
+
+
+def maybe_reformat_dpo_to_kto(dataset: DatasetDict, num_proc: int = None):
[Review thread]
Member: For public methods, would you mind adding a docstring and a unit test please?

Contributor Author: Done!
+    keys = list(dataset["train"].features.keys())
+
+    # check if the dataset is in the KTO format or needs to be reformatted
+    if "prompt" in keys and "completion" in keys and "label" in keys:
+        return dataset
+    elif "prompt" in keys and "rejected" in keys and "chosen" in keys:
+        # remove unnecessary fields
+        keys_to_remove = deepcopy(keys)
+        keys_to_remove.remove("prompt")
+        keys_to_remove.remove("chosen")
+        keys_to_remove.remove("rejected")
+        dataset = dataset.remove_columns(keys_to_remove)
+
+        # turn each DPO-formatted row into two KTO-formatted rows.
+        dataset = dataset.map(
+            _reformat_row_dpo_to_kto,
+            num_proc=num_proc,
+            batched=True,
+            remove_columns=["chosen", "rejected"],
+            desc="Reformatting Dataset from DPO format to KTO format.",
+        )
+        return dataset
+    else:
+        raise ValueError("Dataset format not compatible with KTO.")
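A minimal usage sketch of the new helper on a toy DPO-formatted `DatasetDict` (the dataset contents are invented, and the column order shown in the comments is illustrative):

```python
from datasets import Dataset, DatasetDict

from trl import maybe_reformat_dpo_to_kto

# A tiny DPO-formatted dataset: one prompt paired with a chosen and a rejected answer
dpo_dataset = DatasetDict(
    {
        "train": Dataset.from_dict(
            {
                "prompt": ["What color is the sky?"],
                "chosen": ["Blue."],
                "rejected": ["Green."],
            }
        )
    }
)

kto_dataset = maybe_reformat_dpo_to_kto(dpo_dataset)
print(kto_dataset["train"].column_names)  # ['prompt', 'completion', 'label']
print(kto_dataset["train"]["label"])      # [True, False]
```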
11 changes: 10 additions & 1 deletion trl/trainer/kto_config.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict, Literal, Optional

 from transformers import TrainingArguments
@@ -27,6 +27,11 @@ class KTOConfig(TrainingArguments):
         command line.

     Parameters:
+        loss_type (`str`, *optional*, defaults to `"kto"`):
+            The type of unpaired loss to use. Possible values are:
+
+            - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
+            - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
         max_length (`int`, *optional*, defaults to `None`):
             The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
         max_prompt_length (`int`, *optional*, defaults to `None`):
@@ -60,6 +65,10 @@
             Number of processes to use for processing the datasets.
     """

+    loss_type: Literal[
+        "kto",
+        "apo_zero_unpaired",
[Review thread]
Member: Is "unpaired" really necessary? As far as I understand, there is no such thing as a "paired" version of KTO, right?

Contributor Author: APO-zero does have a paired and an unpaired variant, and you could definitely construct a paired variant of KTO. We could remove "_unpaired" here since the KTOTrainer already implies it, but I thought it would be good for people to actively think about the distinction when selecting a loss.

Member: Yes, given we also have an apo_zero loss in the DPOTrainer, it's good to retain the _unpaired distinction IMO. Would you mind adding this loss term to the integration tests here:

@parameterized.expand(

You might want to look at the DPO trainer for inspiration:

@parameterized.expand(
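A hedged sketch of what such a parameterized case could look like, modeled on the pattern referenced above (the class name, method name, parameters, and fixtures are assumptions, not code from this PR):

```python
import unittest

from parameterized import parameterized


class KTOTrainerTester(unittest.TestCase):
    # Hypothetical sketch: mirrors the parameterized test style referenced in the
    # review comment; the exact test matrix in TRL's test suite may differ.
    @parameterized.expand(
        [
            ["gpt2", "kto", True, True],
            ["gpt2", "apo_zero_unpaired", True, True],  # exercise the new loss
        ]
    )
    def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset):
        # Build a KTOConfig with loss_type=loss_type and run a short training
        # loop on a tiny model/dataset, asserting that parameters change.
        ...
```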

] = "kto"
max_length: Optional[int] = None
"""The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
max_prompt_length: Optional[int] = None
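For reference, here is a simplified sketch of how the two loss_type values differ, following the KTO and APO papers (this is an illustration under stated assumptions, not a verbatim copy of KTOTrainer internals; variable names are invented):

```python
import torch


def unpaired_losses(chosen_logratios, rejected_logratios, kl, beta, loss_type="kto"):
    # chosen_logratios / rejected_logratios: log(pi_theta / pi_ref) per completion
    # kl: estimate of KL(pi_theta || pi_ref) used as the KTO reference point
    if loss_type == "kto":
        # KTO measures reward deviations relative to the KL reference point
        chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
        rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
    elif loss_type == "apo_zero_unpaired":
        # APO-zero (unpaired) pushes chosen likelihood up and rejected likelihood
        # down in absolute terms, with no KL reference point
        chosen_losses = 1 - torch.sigmoid(beta * chosen_logratios)
        rejected_losses = torch.sigmoid(beta * rejected_logratios)
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")
    return chosen_losses, rejected_losses
```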