-
Notifications
You must be signed in to change notification settings - Fork 16
/
utils.py
312 lines (252 loc) · 10.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
# Copyright (C) 2023 ByteDance. All Rights Reserved.
#
# This software is released under the MIT License.
# https://opensource.org/licenses/MIT
import random
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from transformers import DataCollatorForLanguageModeling
# Seed every RNG used in this module (Python's `random` for subsampling and
# random-answer sampling, NumPy, and PyTorch for `random_split`) so runs are
# reproducible.
random.seed(8888)
np.random.seed(8888)
torch.manual_seed(8888)
def create_pku_dataloader_from_dataset(tokenizer, dataset, fraction=1.0, batch_size=4):
    """
    Given the PKU dataset, create the dataloader on the unlearned harmful Q&A pairs.

    Every response flagged as unsafe (``is_response_0_safe`` /
    ``is_response_1_safe`` is falsy) becomes its own row, so one prompt can
    yield zero, one or two rows.

    Args:
        tokenizer: Tokenizer.
        dataset: Loaded PKU dataset (a single split) with at least the columns
            "prompt", "response_0", "response_1", "is_response_0_safe" and
            "is_response_1_safe". All of its columns are dropped by the map.
        fraction: Probability of keeping each prompt; values < 1 downsample.
        batch_size: Batch size.

    Returns:
        Data loader of PKU harmful Q&A pairs. Rows carry "input_ids",
        "attention_mask" and "start_locs"; the collator adds labels.
    """

    def _preprocess(examples):
        """Turn a batch of PKU rows (dict of lists) into tokenized harmful Q&A rows."""
        results = {"input_ids": [], "attention_mask": [], "start_locs": []}
        for prompt, response_0, response_1, safe_0, safe_1 in zip(
            examples["prompt"],
            examples["response_0"],
            examples["response_1"],
            examples["is_response_0_safe"],
            examples["is_response_1_safe"],
        ):
            # Subsample if needed. One draw per prompt, before filtering, so the
            # seeded random stream is consumed in the same order as before.
            if random.random() > fraction:
                continue
            # Keep only the unsafe (harmful) responses; skip the prompt if none.
            response_list = []
            if not safe_0:
                response_list.append(response_0)
            if not safe_1:
                response_list.append(response_1)
            if not response_list:
                continue
            # Start index of the answer. The prefix must be tokenized WITHOUT
            # padding: with padding="max_length" its length would be the padded
            # length rather than the number of prefix tokens.
            test_text = f"### Question: {prompt}\n ### Answer: "
            start_loc = len(tokenizer(test_text, truncation=True)["input_ids"]) - 1
            for response in response_list:
                text = f"### Question: {prompt}\n ### Answer: {response}"
                tokenized = tokenizer(text, truncation=True, padding="max_length")
                results["input_ids"].append(tokenized["input_ids"])
                results["attention_mask"].append(tokenized["attention_mask"])
                results["start_locs"].append(start_loc)
        return results

    # A batched map that emits more or fewer rows than it reads must drop every
    # original column:
    # https://huggingface.co/docs/datasets/about_map_batch#input-size-output-size
    dataset = dataset.map(
        _preprocess,
        batched=True,
        remove_columns=dataset.column_names,
    )
    dataset.set_format(
        type="torch", columns=["input_ids", "attention_mask", "start_locs"]
    )
    # Add labels and make it data loader.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, collate_fn=data_collator
    )
    return dataloader
def create_truthfulqa_dataloader(tokenizer, batch_size=4, tqa_file_path="data/TruthfulQA.csv"):
    """
    Create the TruthfulQA dataloader for the normal data.

    Each CSV row becomes "### Question: {Question}\\n ### Answer: {Best Answer}",
    tokenized with truncation and padding="max_length". The rows are randomly
    split 0.7/0.1/0.2 into train/val/test with torch's global RNG.

    Args:
        tokenizer: Tokenizer.
        batch_size: Batch size.
        tqa_file_path: Path to the TruthfulQA CSV (needs the "Question" and
            "Best Answer" columns). Defaults to the previously hard-coded path.

    Returns:
        (train, val, test) data loaders of TruthfulQA normal Q&A pairs. All
        three are created with shuffle=True.
    """
    df = pd.read_csv(tqa_file_path)
    data = {"input_ids": [], "attention_mask": []}
    for question, good_answer in zip(df["Question"].values, df["Best Answer"].values):
        text = f"### Question: {question}\n ### Answer: {good_answer}"
        tokenized = tokenizer(text, truncation=True, padding="max_length")
        data["input_ids"].append(tokenized["input_ids"])
        data["attention_mask"].append(tokenized["attention_mask"])
    dataset = Dataset.from_dict(data)

    # Split train/val/test = 0.7/0.1/0.2; the test split absorbs rounding.
    train_len = int(0.7 * len(dataset))
    val_len = int(0.1 * len(dataset))
    test_len = len(dataset) - train_len - val_len
    splits = torch.utils.data.random_split(dataset, [train_len, val_len, test_len])

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    train_dataloader, val_dataloader, test_dataloader = (
        torch.utils.data.DataLoader(
            split, batch_size=batch_size, collate_fn=data_collator, shuffle=True
        )
        for split in splits
    )
    return train_dataloader, val_dataloader, test_dataloader
def get_truthfulQA_answers_plaintext(tqa_file_path="data/TruthfulQA.csv"):
    """
    Get the plain text of TruthfulQA's answers used for random mismatch.

    Args:
        tqa_file_path: Path to the TruthfulQA CSV file.

    Returns:
        A list of answer texts: every "Best Answer" entry, then the individual
        answers of "Correct Answers", then those of "Incorrect Answers". The
        last two columns hold ";"-separated lists, which are split and stripped.
        Missing (empty/NaN) cells are skipped instead of crashing `.split` or
        leaking `nan` into the answer pool.
    """
    df = pd.read_csv(tqa_file_path)
    # "Best Answer" holds a single answer per row.
    all_ans = [ans for ans in df["Best Answer"].values if not pd.isna(ans)]
    # Split "Correct Answers" and "Incorrect Answers" by ";".
    for ans_name in ("Correct Answers", "Incorrect Answers"):
        for answer in df[ans_name].values:
            if pd.isna(answer):
                continue
            all_ans.extend(ans.strip() for ans in answer.split(";"))
    return all_ans
def compute_kl(pretrained_model, current_model, batch, device):
    """
    Compute *forward* KL as the normal utility loss.

    The value returned is the token-level cross-entropy
    H(P, Q) = KL(P || Q) + H(P), averaged over every (batch, position) pair,
    where P is the pretrained model's next-token distribution and Q the current
    model's. H(P) does not depend on the current model, so the gradient equals
    that of the forward KL.

    Args:
        pretrained_model: reference model which is the pretrained (original) model.
        current_model: The current unlearning model.
        batch: A batch of normal data with "input_ids" and "attention_mask".
        device: GPU device.

    Returns:
        The KL loss (scalar tensor).
    """
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    # Only logits are used, so no labels are passed (avoids an unused LM loss).
    normal_outputs = current_model(input_ids, attention_mask=attention_mask)
    with torch.no_grad():
        pretrained_outputs = pretrained_model(input_ids, attention_mask=attention_mask)
    # P: pretrained model; Q: current model.
    prob_p = torch.nn.functional.softmax(pretrained_outputs.logits, -1)
    # log_softmax instead of log(softmax + eps): exact for tiny probabilities,
    # and no -inf/NaN in fp16 where a 1e-12 epsilon underflows to 0.
    log_q = torch.nn.functional.log_softmax(normal_outputs.logits, -1)
    # NOTE(review): padding positions are included in the mean, as before.
    loss = -(prob_p * log_q).sum(-1).mean()
    return loss
def get_answer_loss(operation, batch, model, device="cuda:0"):
    """
    Compute the loss on the answer (i.e. y) part.

    For each sequence, the next-token cross-entropy is averaged over the target
    positions that (a) lie after ``start_locs`` and (b) are not ignored in
    ``labels`` (value -100). The per-sequence losses are then averaged over the
    batch; a sequence with no such position contributes 0. Masking on ignored
    labels, instead of a hard-coded pad token id, works for any tokenizer whose
    padding the collator marks with -100.

    Args:
        operation: either "ga" (gradient ascent) or "gd" (gradient descent).
            For "ga" the loss is negated.
        batch: A batch of data with "input_ids", "attention_mask",
            "start_locs" and "labels".
        model: The unlearned model.
        device: GPU device.

    Returns:
        The loss (scalar tensor).
    """
    assert operation in ["ga", "gd"], "Operation must be either GA or GD."
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    outputs = model(input_ids, attention_mask=attention_mask)
    # Shift one to predict next token: logits at position j score token j + 1.
    shift_logits = outputs.logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    # Per-position loss, shape (batch, seq_len - 1); ignored targets (-100) give 0.
    position_loss = torch.nn.functional.cross_entropy(
        shift_logits.transpose(1, 2), shift_labels, reduction="none"
    )
    if operation == "ga":  # Negative the direction for GA.
        position_loss = -position_loss

    # Equal weight on every scored answer position.
    positions = torch.arange(shift_labels.shape[1], device=shift_labels.device)
    start_locs = torch.as_tensor(batch["start_locs"], device=shift_labels.device).view(-1, 1)
    position_weight = ((positions >= start_locs) & (shift_labels != -100)).float()
    # Normalize per sequence; the clamp keeps all-zero rows at 0 instead of NaN.
    position_weight = position_weight / position_weight.sum(-1, keepdim=True).clamp(min=1)
    return (position_weight * position_loss).sum(-1).mean()
def get_rand_ans_loss(bad_batch, tokenizer, normal_ans, model, K=5, device="cuda:0"):
    """
    Compute the loss of the random mismatch.

    The question of every forgetting sample is recovered from its decoded text
    and paired with each of K answers sampled once per call (shared by all
    samples) from ``normal_ans``. The gradient-descent answer loss on these
    mismatched pairs is returned.

    Args:
        bad_batch: A batch of forgetting data.
        tokenizer: The tokenizer.
        normal_ans: A list of random answers.
        model: unlearned model.
        K: How many random answers sampled for each forgetting sample.
        device: GPU device.

    Returns:
        The random mismatch loss.
    """
    bad_input_ids = bad_batch["input_ids"].to(device)
    rand_ans_list = random.sample(normal_ans, k=K)
    batch_random_features = []
    for single_input_id in bad_input_ids:
        ori_text = tokenizer.decode(single_input_id)
        # Get question.
        question = ori_text.split("###")[1].split("Question:")[-1].strip()
        question_prefix = f"### Question: {question}\n ### Answer: "
        # Start index of the answer, using the PKU loader's convention
        # (prefix token count - 1). The prefix is tokenized WITHOUT padding, and
        # the count is taken from "input_ids": the tokenizer output is dict-like,
        # so len() on it would count its keys, not its tokens.
        tokenized_question_prefix = tokenizer(question_prefix, truncation=True)
        start_loc = len(tokenized_question_prefix["input_ids"]) - 1
        # Get random answer.
        for rand_ans in rand_ans_list:
            random_sample = f"{question_prefix}{rand_ans}"
            tokenized_rs = tokenizer(
                random_sample, truncation=True, padding="max_length"
            )
            batch_random_features.append(
                {
                    "input_ids": tokenized_rs["input_ids"],
                    "attention_mask": tokenized_rs["attention_mask"],
                    "start_locs": start_loc,
                }
            )
    # Batchify; the collator also adds "labels".
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    batch_random = data_collator(batch_random_features)
    # GD on answer.
    random_loss = get_answer_loss("gd", batch_random, model, device=device)
    return random_loss