-
Notifications
You must be signed in to change notification settings - Fork 0
/
autocompressor_inference.py
174 lines (156 loc) · 7.17 KB
/
autocompressor_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import argparse
import json
import pdb
import re
import time

import torch
from transformers import AutoTokenizer

from auto_compressor import AutoCompressorModel
import utils
def main():
    """Run zero-shot-style inference with a distilled demo prompt and report accuracy.

    Loads the AutoCompressor model/tokenizer, prepends the demo text read from
    ``args.demo_path`` to each question, takes the argmax of the final logit
    position as the model's one-token answer, and scores it against the ground
    truth for the ``boolq`` and ``multiple_rc`` datasets. Results and the list
    of wrong answers are appended to files under ``./summaries`` and
    ``./wrong_lists`` (those directories must already exist).
    """
    args = arg_parser()
    print('*****************************')
    print(args)
    print('*****************************')
    tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/AutoCompressor-2.7b-6k")
    model = AutoCompressorModel.from_pretrained("princeton-nlp/AutoCompressor-2.7b-6k").eval()
    try:
        with open(args.demo_path, "r") as file:
            demo = file.read()
    except FileNotFoundError:
        # Bail out here: the original code fell through and later crashed
        # with a NameError because `demo` was never bound.
        print("Your demo path doesn't exist. Please try another path.")
        return
    dataloader = utils.create_dataloader(args)
    correct = 0
    wrong_list = []
    if args.qes_limit == 0:
        args.qes_limit = len(dataloader)
    answer_list = []
    gt_list = []
    num_evaluated = 0  # samples actually scored; may be < qes_limit if the dataset is small
    start = time.time()
    for count, qa in enumerate(dataloader):
        # argparse always yields an int here, so no None check is needed.
        if count == args.qes_limit:
            break
        if args.dataset == 'boolq':
            message = (demo + '\nFollow the given examples and answer the following question with true or false: ' + qa['question'] + ' Answer is: ')
        elif args.dataset == 'multiple_rc':
            message = (demo + '\nFollow the given examples and read the given passage carefully. A student has given his answer to the question based on the passage. Your task is to respond whether the student\'s answer is correct or wrong.\n' + 'User: ' + qa['question'] + '\nResponse: The student\'s answer is ')
        else:
            # NotImplementedError, not the NotImplemented singleton: raising
            # the latter is a TypeError at runtime.
            raise NotImplementedError
        print(message)
        message_tokens = tokenizer(message, return_tensors="pt").input_ids
        # Inference only — skip autograd bookkeeping to save memory.
        with torch.no_grad():
            output = model(message_tokens)
        # Greedy-decode the single next token; if it is whitespace, fall back
        # to the prediction at the previous position.
        last = tokenizer.decode(output.logits[0, -1].argmax())
        if last == '\n' or last == ' ':
            answer = tokenizer.decode(output.logits[0, -2].argmax()).lower()
        else:
            answer = last.lower()
        print(f"answer is: {answer}")
        print(f"ground truth is: {qa['answer']}")
        answer_list.append(answer)
        gt_list.append(qa['answer'])
        num_evaluated += 1
        if args.dataset == 'multiple_rc':
            if qa['answer'].lower() in answer:
                print('yes')
                correct += 1
            else:
                wrong_list.append({'question': qa['question'], 'answer': answer, 'ground_truth': qa['answer']})
        elif args.dataset == 'boolq':
            if qa['answer'] == True:
                if 'yes' in answer or 'true' in answer:
                    correct += 1
                else:
                    wrong_list.append({'question': qa['question'], 'answer': answer, 'ground_truth': qa['answer']})
            else:
                if 'no' in answer or 'false' in answer:
                    correct += 1
                else:
                    wrong_list.append({'question': qa['question'], 'answer': answer, 'ground_truth': qa['answer']})
    end = time.time()
    # Divide by the number actually evaluated, not qes_limit: the dataloader
    # may hold fewer samples than the requested limit.
    accuracy = correct / num_evaluated if num_evaluated else 0.0
    print(f"Answer list = {answer_list}")
    print(f"GT list = {gt_list}")
    print(f"Total correct number: {correct}")
    print(f"Correct Percentage: {accuracy}")
    print(f"Execution time: {end - start} seconds")
    summary_path = f"./summaries/one_prompt_round/{args.qes_limit}_{args.demo_path.split('/')[-1]}"
    with open(summary_path, "a") as f:
        f.write(f"Total correct number: {correct}\n")
        f.write(f"Correct Percentage: {accuracy}\n")
        f.write(f"Execution time: {end - start} seconds")
    wrong_list_path = f"./wrong_lists/{args.qes_limit}_{args.demo_path.split('/')[-1]}"
    with open(wrong_list_path, "a") as f:
        f.write(json.dumps(wrong_list, indent=4))
def arg_parser():
    """Parse command-line arguments and derive temperature and dataset path.

    Returns:
        argparse.Namespace with all declared flags, plus two derived fields:
        ``temperature`` (0.7 when self-consistency ``multipath`` > 1, else 0)
        and ``dataset_path`` (resolved from ``dataset`` unless the dataset
        supplies a HuggingFace identifier such as "boolq").

    Raises:
        ValueError: if ``dataset`` has no known path mapping (defensive;
        argparse ``choices`` should already prevent this).
    """
    def _parse_bool(value):
        # argparse's `type=bool` is broken: bool("False") is True. Accept the
        # usual textual spellings instead so `--distill False` actually works.
        return str(value).strip().lower() in ("true", "1", "yes", "y", "t")

    parser = argparse.ArgumentParser(description="Inference with selected prompts.")
    parser.add_argument("--random_seed", type=int, default=42, help="random seed")
    parser.add_argument(
        "--dataset", type=str, default="gsm8k", choices=["multiple_rc", "boolq", "squad", "gsm8k", "svamp", "aqua", "csqa", "asdiv", "last_letters", "addsub", "singleeq", "strategyqa", "multiarith"], help="dataset to inference"
    )
    parser.add_argument(
        "--dataset_path", type=str, default="./dataset/GSM8K/"
    )
    parser.add_argument(
        "--trainset_path", type=str, default="./dataset/GSM8K/train.jsonl", help="prompts to use"
    )
    parser.add_argument(
        "--demo_path", type=str, default="./distilled_demos/gsm8k_Llama-2-13b-chat-hf_4_2_trainsplit_42.txt", help="path to distilled demos"
    )
    parser.add_argument(
        "--QA_dir", type=str, default="./QA_records/", help="output directory for QA records"
    )
    parser.add_argument(
        "--wrong_list_dir", type=str, default="./wrong_lists/", help="output directory for wrong lists"
    )
    parser.add_argument(
        "--max_length_cot", type=int, default=512, help="maximum length of output tokens by model for reasoning extraction"
    )
    parser.add_argument(
        "--qes_limit", type=int, default=10, help="whether to limit test dataset size. if 0, the dataset size is unlimited and we use all the samples in the dataset for testing."
    )
    parser.add_argument(
        "--multipath", type=int, default=1, help="self-consistency path num"
    )
    parser.add_argument(
        "--concat_length", type=int, default=4, help='Used for task last_letters, indicates length of last letter to concat, i.e. Elon Musk -> nk, use concat length of 2'
    )
    parser.add_argument(
        "--use_code_style_prompt", type=_parse_bool, default=False, help='Use code-style prompt as mentioned in paper for last_letters dataset'
    )
    parser.add_argument(
        "--distill", type=_parse_bool, default=False, help="whether load training set"
    )
    args = parser.parse_args()

    # Sampling temperature: >0 only when self-consistency uses multiple paths.
    args.temperature = 0.7 if args.multipath > 1 else 0
    print(f"Temperature: {args.temperature}")

    # One path per dataset. "squad" and "boolq" are HuggingFace dataset ids,
    # not filesystem paths. Note the original code used `==` (a no-op
    # comparison) for multiple_rc, so its path was never actually assigned.
    dataset_paths = {
        "gsm8k": "./dataset/GSM8K/test.jsonl",
        "svamp": "./dataset/SVAMP/SVAMP.json",
        "asdiv": "./dataset/ASDiv/ASDiv.json",
        "aqua": "./dataset/AQuA/test.json",
        "csqa": "./dataset/CSQA/dev_rand_split.jsonl",
        "strategyqa": "./dataset/strategyQA/task.json",
        "last_letters": "./dataset/last_letters/last_letters_test.json",
        "addsub": "./dataset/MAWPS/AddSub.json",
        "singleeq": "./dataset/MAWPS/SingleEq.json",
        "multiarith": "./dataset/MAWPS/MultiArith.json",
        "squad": "squad_v2",
        "boolq": "boolq",
        "multiple_rc": "./dataset/MultiRC/",
    }
    try:
        args.dataset_path = dataset_paths[args.dataset]
    except KeyError:
        raise ValueError("dataset is not properly defined ...")
    return args
# Standard entry-point guard: run inference only when executed as a script.
if __name__ == "__main__":
    main()