-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathapi_client.py
73 lines (53 loc) · 2.14 KB
/
api_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import asyncio
import string
from semantic_parsing_with_constrained_lm.src.semantic_parsing_with_constrained_lm.lm_openai_gpt3 import GPT3Client
import json
from omegaconf import OmegaConf
import more_itertools
import tqdm
import hydra.utils as hu
import hydra
import random
import os
@hydra.main(config_path="configs", config_name="client")
def main(cfg):
    """Request completions for a sample of dataset entries and dump them to JSON.

    Reads the API key from the ``OPENAI_TOKEN`` environment variable and the
    following keys from ``cfg``: ``engine``, ``batch_size``,
    ``dataset_reader``, ``output_file`` and, optionally, ``slice``.

    Raises:
        KeyError: if ``OPENAI_TOKEN`` is not set in the environment.
    """
    print(cfg)
    client = GPT3Client(api_key=os.environ["OPENAI_TOKEN"])

    async def get_pred(entry_list):
        """Complete one batch of entries and store each result under ``'generated'``.

        Every entry's ``'enc_text'`` is sent as one prompt of a single batched
        request. The entries are mutated in place and the same list is returned.

        Raises:
            ValueError: if the response does not hold one choice per prompt.
        """
        prompts = [entry['enc_text'] for entry in entry_list]
        args = {
            "prompt": prompts,
            "max_tokens": 280,
            "stop": ["\n"],
            "echo": False,
            "logprobs": 1,
        }
        response = await client.completions_rate_limited(cfg.engine, args)  # type: ignore
        choices = response.json()['choices']
        if len(choices) != len(entry_list):
            raise ValueError(
                f"expected {len(entry_list)} completions, got {len(choices)}"
            )
        # NOTE(review): assumes the payload follows the OpenAI completions
        # format, where batched choices may come back out of order and each
        # carries the `index` of its prompt. Fall back to the list position
        # when no `index` is present.
        for position, choice in enumerate(choices):
            entry_list[choice.get('index', position)]['generated'] = choice['text']
        return entry_list

    async def run(data_list):
        """Schedule one ``get_pred`` task per batch and await them all.

        Returns the per-batch entry lists in completion order (not submission
        order), because results are collected through ``asyncio.as_completed``.
        """
        task_list = [
            asyncio.create_task(get_pred(batch))
            for batch in more_itertools.chunked(data_list, cfg.batch_size)
        ]
        progress = tqdm.tqdm(asyncio.as_completed(task_list), total=len(task_list))
        return [await finished for finished in progress]

    def run_main(cfg):
        """Sample the dataset, fetch completions and write them to ``cfg.output_file``."""
        seed = 42           # fixed seed keeps the sampled subset reproducible
        sample_size = 1000  # number of shuffled examples kept
        slice_size = 100    # examples per `cfg.slice` window

        dataset_reader = hu.instantiate(cfg.dataset_reader)
        idx_list = list(range(len(dataset_reader)))
        random.Random(seed).shuffle(idx_list)
        idx_list = idx_list[:sample_size]
        if "slice" in cfg:
            # Restrict the run to the `cfg.slice`-th window of the sample.
            start = cfg.slice * slice_size
            idx_list = idx_list[start:start + slice_size]
        data_list = [dataset_reader[idx]['metadata'] for idx in idx_list]

        res = asyncio.run(run(data_list))
        # `run` yields one list per batch; flatten a single level into entries.
        res = list(more_itertools.collapse(res, levels=1))
        with open(cfg.output_file, "w", encoding="utf-8") as f:
            json.dump(res, f)

    run_main(cfg)
# Example usage: python api_client.py prompt_file=$PWD/data/random_smcalflow_valid.json task_name=smcalflow output_file=$PWD/data/test.json engine=davinci-codex +stop=true
# Invoke the Hydra-decorated entry point only when this file is run as a script.
if __name__ == "__main__":
    main()