
Commit 7acb9c2

Authored by KarelDO, karel-contextual, qgallouedec, and lewtun
Feat: Add support for APO-zero in KTOTrainer (#1952)
* feat: add kto command
* feat: add support for apo loss in KTO Trainer
* feat: make kto script compatible with dpo-formatted datasets
* fix: lint data utils
* add loss_type in kto test
* fix: data utils docstrings
* fix: add dataset reformat test
* fix: lint tests
* fix: only reference kl_logps if needed

---------

Co-authored-by: Karel D'Oosterlinck <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: lewtun <[email protected]>
1 parent 6840380 commit 7acb9c2
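
The headline change is a new `loss_type` switch on `KTOConfig`. Below is a minimal sketch of how a user might opt into the unpaired APO-zero loss after this commit (the model and dataset names are placeholders, not part of the diff):

```python
# Minimal sketch of selecting the new APO-zero unpaired loss.
# The model and dataset names are placeholders; the only piece this
# commit actually adds is `loss_type="apo_zero_unpaired"`.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import KTOConfig, KTOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

# Any KTO-formatted dataset with "prompt", "completion" and "label" columns
dataset = load_dataset("trl-lib/kto-mix-14k", split="train")  # placeholder name

training_args = KTOConfig(
    output_dir="kto-apo-zero",
    loss_type="apo_zero_unpaired",  # new in this commit; the default stays "kto"
    beta=0.1,
    report_to="none",
)
trainer = KTOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```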

File tree: 8 files changed (+340 −135 lines)

Diff for: examples/scripts/kto.py (+10 −3)

```diff
@@ -59,7 +59,7 @@
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
 
-from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format
+from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_reformat_dpo_to_kto, setup_chat_format
 
 
 # Define and parse arguments.
@@ -97,10 +97,17 @@ class ScriptArguments:
     # Load the dataset
     dataset = load_dataset(script_args.dataset_name)
 
+    # If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
+    dataset = maybe_reformat_dpo_to_kto(dataset, num_proc=kto_args.dataset_num_proc)
+
     # Apply chat template
     def format_dataset(example):
-        example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
-        example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
+        if isinstance(example["completion"], str):
+            example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
+            example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
+        else:
+            example["prompt"] = tokenizer.apply_chat_template(example["completion"][:-1], tokenize=False)
+            example["completion"] = tokenizer.apply_chat_template([example["completion"][-1]], tokenize=False)
         return example
 
     # Compute that only on the main process for faster data processing.
```
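
The new `isinstance` branch is there because a dataset run through `maybe_reformat_dpo_to_kto` can carry conversational completions (a list of chat messages) instead of flat strings. Here is a tokenizer-free sketch of the splitting rule in the `else` branch; `split_last_turn` is a hypothetical helper, not part of the commit:

```python
# Illustration of the else-branch above: for conversational data, every turn
# except the last becomes the prompt and the final turn becomes the
# completion (the script then serializes both with apply_chat_template).
def split_last_turn(example: dict) -> dict:
    if isinstance(example["completion"], str):
        return example  # already flat: string prompt/completion stay as-is
    return {
        "prompt": example["completion"][:-1],       # all turns but the last
        "completion": [example["completion"][-1]],  # the final turn
        "label": example["label"],
    }

row = {
    "completion": [
        {"role": "user", "content": "What is AI?"},
        {"role": "assistant", "content": "AI is artificial intelligence."},
    ],
    "label": True,
}
print(split_last_turn(row)["prompt"])      # [{'role': 'user', 'content': 'What is AI?'}]
print(split_last_turn(row)["completion"])  # [{'role': 'assistant', ...}]
```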

Diff for: tests/test_dataset_reformat.py (new file, +71)

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from datasets import Dataset, DatasetDict

from trl.data_utils import maybe_reformat_dpo_to_kto


class MaybeReformatDPOToKTOTester(unittest.TestCase):
    def setUp(self):
        # Create a sample DPO-formatted dataset for testing
        self.dpo_data = {
            "prompt": ["What is AI?", "Define machine learning."],
            "chosen": ["AI is artificial intelligence.", "Machine learning is a subset of AI."],
            "rejected": ["AI is a computer.", "Machine learning is a program."],
        }
        self.dpo_dataset = DatasetDict({"train": Dataset.from_dict(self.dpo_data)})

        # Create a sample KTO-formatted dataset for testing
        self.kto_data = {
            "prompt": ["What is AI?", "Define machine learning.", "What is AI?", "Define machine learning."],
            "completion": [
                "AI is artificial intelligence.",
                "Machine learning is a subset of AI.",
                "AI is a computer.",
                "Machine learning is a program.",
            ],
            "label": [True, True, False, False],
        }
        self.kto_dataset = DatasetDict({"train": Dataset.from_dict(self.kto_data)})

    def test_dpo_to_kto_conversion(self):
        # Test that a DPO-formatted dataset is correctly reformatted to KTO format
        reformatted_dataset = maybe_reformat_dpo_to_kto(self.dpo_dataset)
        self.assertEqual(
            reformatted_dataset["train"].to_dict(),
            self.kto_dataset["train"].to_dict(),
            "The DPO-formatted dataset was not correctly reformatted to KTO format.",
        )

    def test_already_kto_format(self):
        # Test that a KTO-formatted dataset remains unchanged
        reformatted_dataset = maybe_reformat_dpo_to_kto(self.kto_dataset)
        self.assertEqual(
            reformatted_dataset["train"].to_dict(),
            self.kto_dataset["train"].to_dict(),
            "The KTO-formatted dataset should remain unchanged.",
        )

    def test_invalid_format(self):
        # Test that a dataset with an incompatible format raises a ValueError
        invalid_data = {
            "input": ["What is AI?", "Define machine learning."],
            "output": ["AI is artificial intelligence.", "Machine learning is a subset of AI."],
        }
        invalid_dataset = DatasetDict({"train": Dataset.from_dict(invalid_data)})

        with self.assertRaises(ValueError):
            maybe_reformat_dpo_to_kto(invalid_dataset)
```

Diff for: tests/test_kto_trainer.py (+10 −7)

```diff
@@ -75,15 +75,17 @@ def _init_dummy_dataset(self):
 
     @parameterized.expand(
         [
-            ["gpt2", True, True],
-            ["gpt2", True, False],
-            # ["t5", True],
-            ["gpt2", False, True],
-            ["gpt2", False, False],
-            # ["t5", False],
+            ["gpt2", "kto", True, True],
+            ["gpt2", "kto", True, False],
+            ["gpt2", "kto", False, True],
+            ["gpt2", "kto", False, False],
+            ["gpt2", "apo_zero_unpaired", True, True],
+            ["gpt2", "apo_zero_unpaired", True, False],
+            ["gpt2", "apo_zero_unpaired", False, True],
+            ["gpt2", "apo_zero_unpaired", False, False],
         ]
     )
-    def test_kto_trainer(self, name, pre_compute, eval_dataset):
+    def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = KTOConfig(
                 output_dir=tmp_dir,
@@ -95,6 +97,7 @@ def test_kto_trainer(self, name, pre_compute, eval_dataset):
                 eval_strategy="steps",
                 beta=0.1,
                 precompute_ref_log_probs=pre_compute,
+                loss_type=loss_type,
                 report_to="none",
             )
```
Diff for: trl/__init__.py (+2)

```diff
@@ -81,6 +81,7 @@
         "MultitaskPromptTuningConfig",
         "MultitaskPromptTuningInit",
     ],
+    "data_utils": ["maybe_reformat_dpo_to_kto"],
 }
 
 try:
@@ -162,6 +163,7 @@
     from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
     from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
     from .commands.cli_utils import init_zero_verbose, SFTScriptArguments, DPOScriptArguments, TrlParser
+    from .data_utils import maybe_reformat_dpo_to_kto
 
 try:
     if not is_diffusers_available():
```

Diff for: trl/commands/cli.py (+1 −1)

```diff
@@ -21,7 +21,7 @@
 from rich.console import Console
 
 
-SUPPORTED_COMMANDS = ["sft", "dpo", "chat"]
+SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto"]
 
 
 def main():
```
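
Registering `"kto"` here is what makes the commit message's "add kto command" work end to end: the `trl` entry point now accepts `trl kto ...` and, as with `sft` and `dpo`, dispatches it to the corresponding training script (the new `examples/scripts/kto.py`).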

Diff for: trl/data_utils.py (new file, +74)

```python
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

from datasets import DatasetDict


def _reformat_row_dpo_to_kto(row: dict):
    """Turn a DPO-formatted dataset row into two KTO-formatted rows."""

    chosen_row = {"prompt": row["prompt"], "completion": row["chosen"], "label": [True] * len(row["chosen"])}
    rejected_row = {
        "prompt": row["prompt"],
        "completion": row["rejected"],
        "label": [False] * len(row["chosen"]),
    }
    new_rows = {k: chosen_row[k] + rejected_row[k] for k in chosen_row.keys()}
    return new_rows


def maybe_reformat_dpo_to_kto(dataset: DatasetDict, num_proc: int = None):
    """
    Reformat a dataset from the DPO format to the KTO format if necessary.

    This function checks whether the input dataset is already in the KTO format (containing "prompt", "completion",
    and "label" fields). If the dataset is in DPO format (with "prompt", "chosen", and "rejected" fields), it converts
    it to KTO format by:
    - Removing any unnecessary columns.
    - Reformatting each row to create a unified format suitable for KTO training.

    Args:
        dataset (DatasetDict): The dataset to potentially reformat.
        num_proc (int, optional): The number of processes to use for multiprocessing during dataset transformation.
            Defaults to None.

    Returns:
        DatasetDict: The reformatted dataset, if conversion was needed; otherwise, the original dataset.

    Raises:
        ValueError: If the dataset format is not compatible with KTO or DPO.
    """
    keys = list(dataset["train"].features.keys())

    # Check if the dataset is in the KTO format or needs to be reformatted
    if "prompt" in keys and "completion" in keys and "label" in keys:
        return dataset
    elif "prompt" in keys and "rejected" in keys and "chosen" in keys:
        # Remove unnecessary fields
        keys_to_remove = deepcopy(keys)
        keys_to_remove.remove("prompt")
        keys_to_remove.remove("chosen")
        keys_to_remove.remove("rejected")
        dataset = dataset.remove_columns(keys_to_remove)

        # Turn each DPO-formatted row into two KTO-formatted rows
        dataset = dataset.map(
            _reformat_row_dpo_to_kto,
            num_proc=num_proc,
            batched=True,
            remove_columns=["chosen", "rejected"],
            desc="Reformatting Dataset from DPO format to KTO format.",
        )
        return dataset
    else:
        raise ValueError("Dataset format not compatible with KTO.")
```
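
A runnable usage sketch of `maybe_reformat_dpo_to_kto`, mirroring the new unit tests: each DPO row fans out into two KTO rows, the chosen completions first.

```python
# Each DPO row (prompt, chosen, rejected) becomes two KTO rows
# (prompt, completion, label), with chosen completions listed first.
from datasets import Dataset, DatasetDict

from trl.data_utils import maybe_reformat_dpo_to_kto

dpo_dataset = DatasetDict(
    {
        "train": Dataset.from_dict(
            {
                "prompt": ["What is AI?"],
                "chosen": ["AI is artificial intelligence."],
                "rejected": ["AI is a computer."],
            }
        )
    }
)

kto_dataset = maybe_reformat_dpo_to_kto(dpo_dataset)
print(kto_dataset["train"].to_dict())
# {'prompt': ['What is AI?', 'What is AI?'],
#  'completion': ['AI is artificial intelligence.', 'AI is a computer.'],
#  'label': [True, False]}
```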

Diff for: trl/trainer/kto_config.py (+10 −1)

```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict, Literal, Optional
 
 from transformers import TrainingArguments
 
@@ -27,6 +27,11 @@ class KTOConfig(TrainingArguments):
         command line.
 
     Parameters:
+        loss_type (`str`, *optional*, defaults to `"kto"`):
+            The type of unpaired loss to use. Possible values are:
+
+                - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
+                - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
         max_length (`int`, *optional*, defaults to `None`):
             The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
         max_prompt_length (`int`, *optional*, defaults to `None`):
@@ -60,6 +65,10 @@ class KTOConfig(TrainingArguments):
             Number of processes to use for processing the datasets.
     """
 
+    loss_type: Literal[
+        "kto",
+        "apo_zero_unpaired",
+    ] = "kto"
     max_length: Optional[int] = None
     """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
     max_prompt_length: Optional[int] = None
```
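
Note that the `kto_trainer.py` hunk that actually consumes `loss_type` is not shown in this diff view, so the following is only a sketch of how the two unpaired losses differ, written from the KTO and APO papers cited in the docstring: KTO scores each example against a batch-level KL anchor, while APO-zero-unpaired drops the anchor and pushes chosen log-ratios up and rejected log-ratios down directly. All names in the sketch are illustrative.

```python
# Sketch (not this commit's code) of the two unpaired losses, per the cited
# papers. policy_logps/ref_logps are per-example sequence log-probs and kl
# stands in for KTO's batch-level KL estimate.
import torch


def unpaired_loss(policy_logps, ref_logps, desirable, kl, beta=0.1, loss_type="kto"):
    logratio = policy_logps - ref_logps
    if loss_type == "kto":
        # KTO: reward desirable examples relative to the KL anchor,
        # penalize undesirable ones relative to it.
        if desirable:
            return 1 - torch.sigmoid(beta * (logratio - kl))
        return 1 - torch.sigmoid(beta * (kl - logratio))
    if loss_type == "apo_zero_unpaired":
        # APO-zero (unpaired): no KL anchor; directly raise the likelihood of
        # desirable completions and lower that of undesirable ones.
        if desirable:
            return 1 - torch.sigmoid(beta * logratio)
        return torch.sigmoid(beta * logratio)
    raise ValueError(f"Unknown loss_type: {loss_type}")


policy, ref, kl = torch.tensor(-10.0), torch.tensor(-12.0), torch.tensor(0.5)
print(unpaired_loss(policy, ref, True, kl, loss_type="kto"))
print(unpaired_loss(policy, ref, True, kl, loss_type="apo_zero_unpaired"))
```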
