Feat: Add support for APO-zero in KTOTrainer #1952
New file added by the PR (a DPO-to-KTO dataset reformatting helper), `@@ -0,0 +1,56 @@`:

```python
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Optional

from datasets import DatasetDict


def _reformat_row_dpo_to_kto(row: dict):
    # Turn each paired (batched) row into two unpaired rows.
    chosen_row = {"prompt": row["prompt"], "completion": row["chosen"], "label": [True] * len(row["chosen"])}
    rejected_row = {
        "prompt": row["prompt"],
        "completion": row["rejected"],
        "label": [False] * len(row["rejected"]),
    }
    new_rows = {k: chosen_row[k] + rejected_row[k] for k in chosen_row.keys()}
    return new_rows


def maybe_reformat_dpo_to_kto(dataset: DatasetDict, num_proc: Optional[int] = None):
    keys = list(dataset["train"].features.keys())

    # Check whether the dataset is already in the KTO format or needs to be reformatted.
    if "prompt" in keys and "completion" in keys and "label" in keys:
        return dataset
    elif "prompt" in keys and "rejected" in keys and "chosen" in keys:
        # Remove all fields other than "prompt", "chosen", and "rejected".
        keys_to_remove = deepcopy(keys)
        keys_to_remove.remove("prompt")
        keys_to_remove.remove("chosen")
        keys_to_remove.remove("rejected")
        dataset = dataset.remove_columns(keys_to_remove)

        # Turn each DPO-formatted row into two KTO-formatted rows.
        dataset = dataset.map(
            _reformat_row_dpo_to_kto,
            num_proc=num_proc,
            batched=True,
            remove_columns=["chosen", "rejected"],
            desc="Reformatting Dataset from DPO format to KTO format.",
        )
        return dataset
    else:
        raise ValueError("Dataset format not compatible with KTO.")
```
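A quick usage sketch of the helper above (the toy dataset contents are illustrative):

```python
from datasets import Dataset, DatasetDict

# A toy DPO-formatted dataset: one paired row.
dpo_dataset = DatasetDict(
    {
        "train": Dataset.from_dict(
            {
                "prompt": ["What color is the sky?"],
                "chosen": ["The sky is blue."],
                "rejected": ["The sky is green."],
            }
        )
    }
)

kto_dataset = maybe_reformat_dpo_to_kto(dpo_dataset)
# The paired row becomes two unpaired rows:
#   {"prompt": ..., "completion": "The sky is blue.",  "label": True}
#   {"prompt": ..., "completion": "The sky is green.", "label": False}
print(kto_dataset["train"].to_dict())
```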
Changes to the `KTOConfig` dataclass:

```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict, Literal, Optional

 from transformers import TrainingArguments
```
```diff
@@ -27,6 +27,11 @@ class KTOConfig(TrainingArguments):
         command line.

     Parameters:
+        loss_type (`str`, *optional*, defaults to `"kto"`):
+            The type of unpaired loss to use. Possible values are:
+
+            - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
+            - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
         max_length (`int`, *optional*, defaults to `None`):
             The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
         max_prompt_length (`int`, *optional*, defaults to `None`):
```
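The new option is selected like any other `KTOConfig` field; a minimal usage sketch (the `output_dir` and `beta` values here are illustrative):

```python
from trl import KTOConfig

# Select the unpaired APO-zero loss instead of the default KTO loss.
config = KTOConfig(
    output_dir="./kto-apo-zero",
    loss_type="apo_zero_unpaired",
    beta=0.1,
)
```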
```diff
@@ -60,6 +65,10 @@ class KTOConfig(TrainingArguments):
             Number of processes to use for processing the datasets.
     """

+    loss_type: Literal[
+        "kto",
+        "apo_zero_unpaired",
+    ] = "kto"
     max_length: Optional[int] = None
     """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
     max_prompt_length: Optional[int] = None
```

Review discussion on `loss_type`:

Reviewer: Is "unpaired" really necessary? As far as I understand, there is no such thing as a "paired" version of KTO, right?

Author: APO-zero does have a paired and an unpaired variant, and you could definitely construct a paired variant of KTO. We can remove "_unpaired" here since the `KTOTrainer` also implies it, but I thought it would be good for people to actively think about the distinction when selecting a loss.

Reviewer: Yes, given we have an … Would you mind adding this loss term to the integration tests here: Line 76 in 47ab034? You might want to look at the DPO trainer for inspiration: Line 251 in 47ab034.
Review discussion on the new helper:

Reviewer: For public methods, would you mind adding a docstring and a unit test, please?

Author: Done!
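A unit test along the lines the reviewer asks for might look like the sketch below (the `from trl import ...` path is an assumption for illustration; adjust it to wherever the helper is actually exported):

```python
import unittest

from datasets import Dataset, DatasetDict

# NOTE: import path is an assumption; use the module where the helper lives.
from trl import maybe_reformat_dpo_to_kto


class MaybeReformatDPOToKTOTester(unittest.TestCase):
    def setUp(self):
        self.dpo_dataset = DatasetDict(
            {
                "train": Dataset.from_dict(
                    {
                        "prompt": ["p1", "p2"],
                        "chosen": ["c1", "c2"],
                        "rejected": ["r1", "r2"],
                    }
                )
            }
        )

    def test_dpo_is_reformatted_to_kto(self):
        kto_dataset = maybe_reformat_dpo_to_kto(self.dpo_dataset)
        # Two paired rows become four unpaired rows.
        self.assertEqual(len(kto_dataset["train"]), 4)
        self.assertEqual(
            set(kto_dataset["train"].features.keys()),
            {"prompt", "completion", "label"},
        )

    def test_kto_format_passes_through_unchanged(self):
        kto_dataset = maybe_reformat_dpo_to_kto(self.dpo_dataset)
        # A dataset already in KTO format is returned as-is.
        self.assertIs(maybe_reformat_dpo_to_kto(kto_dataset), kto_dataset)
```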