-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdpo.py
197 lines (175 loc) · 7.42 KB
/
dpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
python examples/scripts/dpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="dpo_anthropic_hh" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# peft:
python examples/scripts/dpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="dpo_anthropic_hh" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""
from dataclasses import dataclass, field
from typing import Dict, Optional
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from trl import DPOTrainer, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config
@dataclass
class ScriptArguments:
beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
max_length: int = field(default=512, metadata={"help": "max length of each sample"})
max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
max_target_length: int = field(
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
)
sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"})
ignore_bias_buffers: bool = field(
default=False,
metadata={
"help": "debug argument for distributed training;"
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = "\n\nAssistant:"
search_term_idx = prompt_and_response.rfind(search_term)
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
return prompt_and_response[: search_term_idx + len(search_term)]
def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
}
Prompts should be structured as follows:
\n\nHuman: <prompt>\n\nAssistant:
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
"""
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))
def split_prompt_and_responses(sample) -> Dict[str, str]:
prompt = extract_anthropic_prompt(sample["chosen"])
return {
"prompt": prompt,
"chosen": sample["chosen"][len(prompt) :],
"rejected": sample["rejected"][len(prompt) :],
}
return dataset.map(split_prompt_and_responses)
def get_ultrabin(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split=split, cache_dir=cache_dir)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))
def split_prompt_and_responses(sample) -> Dict[str, str]:
return {
"prompt": f"<|user|>\n{sample['prompt']}\n<|assistant|>\n",
"chosen": sample["chosen"][1]["content"],
"rejected": sample["rejected"][1]["content"],
}
return dataset.map(split_prompt_and_responses)
return dataset.from_dict(flat_data)
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_into_dataclasses()
################
# Model & Tokenizer
################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=True,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
peft_config = get_peft_config(model_config)
if peft_config is None:
model_ref = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
else:
model_ref = None
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]
################
# Dataset
################
#train_dataset = get_hh("train", sanity_check=args.sanity_check)
#eval_dataset = get_hh("test", sanity_check=args.sanity_check)
train_dataset = get_ultrabin("train_prefs", sanity_check=args.sanity_check)
eval_dataset = get_ultrabin("test_prefs", sanity_check=args.sanity_check)
################
# Training
################
trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=args.max_length,
max_target_length=args.max_target_length,
max_prompt_length=args.max_prompt_length,
generate_during_eval=args.generate_during_eval,
peft_config=get_peft_config(model_config),
)
trainer.train()
trainer.save_model(training_args.output_dir)