__author__ = 'qiao'

'''
GeneGPT: teach LLMs to use NCBI API
'''

import json
import os
import re
import sys
import time
import urllib.request

import openai

import config

openai.api_key = config.API_KEY
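

# Issue an HTTP GET request to an NCBI endpoint and return the raw response
# bytes. The one-second sleep helps keep the request rate within NCBI limits.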
def call_api(url):
    time.sleep(1)
    url = url.replace(' ', '+')
    print(url)

    req = urllib.request.Request(url)
    with urllib.request.urlopen(req) as response:
        call = response.read()

    return call
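

# The prompt header consists of up to six components: two API documentations
# (Dc.1 Eutils, Dc.2 BLAST) and four worked demonstrations (Dm.1-4), selected
# by the mask argument.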
def get_prompt_header(mask):
    '''
    mask: [1/0 x 6], denotes whether each prompt component is used

    output: prompt
    '''
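    # Pre-compute the example API calls whose raw responses are spliced into
    # the in-context demonstrations below.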
    url_1 = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=gene&retmax=5&retmode=json&sort=relevance&term=LMP10'
    call_1 = call_api(url_1)

    url_2 = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=gene&retmax=5&retmode=json&id=19171,5699,8138'
    call_2 = call_api(url_2)

    url_3 = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmax=10&retmode=json&id=1217074595'
    call_3 = call_api(url_3)

    url_4 = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=omim&retmax=20&retmode=json&sort=relevance&term=Meesmann+corneal+dystrophy'
    call_4 = call_api(url_4)

    url_5 = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=omim&retmax=20&retmode=json&id=618767,601687,300778,148043,122100'
    call_5 = call_api(url_5)
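
    # BLAST is asynchronous: the Put request returns a request ID (RID), and
    # results are retrieved with a Get request once the search has finished,
    # hence the 30-second wait before fetching.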
    url_6 = 'https://blast.ncbi.nlm.nih.gov/blast/Blast.cgi?CMD=Put&PROGRAM=blastn&MEGABLAST=on&DATABASE=nt&FORMAT_TYPE=XML&QUERY=ATTCTGCCTTTAGTAATTTGATGACAGAGACTTCTTGGGAACCACAGCCAGGGAGCCACCCTTTACTCCACCAACAGGTGGCTTATATCCAATCTGAGAAAGAAAGAAAAAAAAAAAAGTATTTCTCT&HITLIST_SIZE=5'
    call_6 = call_api(url_6)
    rid = re.search('RID = (.*)\n', call_6.decode('utf-8')).group(1)

    url_7 = f'https://blast.ncbi.nlm.nih.gov/blast/Blast.cgi?CMD=Get&FORMAT_TYPE=Text&RID={rid}'
    time.sleep(30)
    call_7 = call_api(url_7)
    prompt = ''
    prompt += 'Hello. Your task is to use NCBI Web APIs to answer genomic questions.\n'
    #prompt += 'There are two types of Web APIs you can use: Eutils and BLAST.\n\n'

    if mask[0]:
        # Doc 0 is about Eutils
        prompt += 'You can call Eutils by: "[https://eutils.ncbi.nlm.nih.gov/entrez/eutils/{esearch|efetch|esummary}.fcgi?db={gene|snp|omim}&retmax={}&{term|id}={term|id}]".\n'
        prompt += 'esearch: input is a search term and output is database id(s).\n'
        prompt += 'efetch/esummary: input is database id(s) and output is full records or summaries that contain name, chromosome location, and other information.\n'
        prompt += 'Normally, you need to first call esearch to get the database id(s) of the search term, and then call efetch/esummary to get the information with the database id(s).\n'
        prompt += 'Database: gene is for genes, snp is for SNPs, and omim is for genetic diseases.\n\n'

    if mask[1]:
        # Doc 1 is about BLAST
        prompt += 'For DNA sequences, you can use BLAST by: "[https://blast.ncbi.nlm.nih.gov/blast/Blast.cgi?CMD={Put|Get}&PROGRAM=blastn&MEGABLAST=on&DATABASE=nt&FORMAT_TYPE={XML|Text}&QUERY={sequence}&HITLIST_SIZE={max_hit_size}]".\n'
        prompt += 'BLAST maps a specific DNA {sequence} to its chromosome location among different species.\n'
        prompt += 'You need to first PUT the BLAST request and then GET the results using the RID returned by PUT.\n\n'
    if any(mask[2:]):
        prompt += 'Here are some examples:\n\n'

    if mask[2]:
        # Example 1 is from the gene alias task
        prompt += f'Question: What is the official gene symbol of LMP10?\n'
        prompt += f'[{url_1}]->[{call_1}]\n'
        prompt += f'[{url_2}]->[{call_2}]\n'
        prompt += f'Answer: PSMB10\n\n'

    if mask[3]:
        # Example 2 is from the SNP gene task
        prompt += f'Question: Which gene is SNP rs1217074595 associated with?\n'
        prompt += f'[{url_3}]->[{call_3}]\n'
        prompt += f'Answer: LINC01270\n\n'

    if mask[4]:
        # Example 3 is from the gene-disease association task
        prompt += f'Question: What are genes related to Meesmann corneal dystrophy?\n'
        prompt += f'[{url_4}]->[{call_4}]\n'
        prompt += f'[{url_5}]->[{call_5}]\n'
        prompt += f'Answer: KRT12, KRT3\n\n'

    if mask[5]:
        # Example 4 is for BLAST
        prompt += f'Question: Align the DNA sequence to the human genome:ATTCTGCCTTTAGTAATTTGATGACAGAGACTTCTTGGGAACCACAGCCAGGGAGCCACCCTTTACTCCACCAACAGGTGGCTTATATCCAATCTGAGAAAGAAAGAAAAAAAAAAAAGTATTTCTCT\n'
        prompt += f'[{url_6}]->[{rid}]\n'
        prompt += f'[{url_7}]->[{call_7}]\n'
        prompt += f'Answer: chr15:91950805-91950932\n\n'

    return prompt
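

# Entry point. The single command-line argument is a string of six 0/1 digits
# (e.g. "111111") selecting which prompt components to include; predictions
# for each GeneTuring task are written to a directory named after that mask.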
if __name__ == '__main__':
    # rough number of chars for truncating
    # codex accepts 8k tokens ~ 18k chars
    cut_length = 18000

    # str_mask is a string of six 0/1 digits marking whether each in-context
    # learning component is used; the six digits correspond to the two
    # documentations (Dc.1-2) and the four demonstrations (Dm.1-4)
    str_mask = sys.argv[1]
    mask = [bool(int(x)) for x in str_mask]
    prompt = get_prompt_header(mask)

    # results are saved in a directory named after the six-digit mask
    if not os.path.isdir(str_mask):
        os.mkdir(str_mask)

    # initialize
    prev_call = time.time()

    qas = json.load(open('data/geneturing.json'))
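
    # Each GeneTuring task contains 50 question-answer pairs; a task is
    # skipped if a complete prediction file for it already exists.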
    for task, info in qas.items():
        if os.path.exists(os.path.join(str_mask, f'{task}.json')):
            # skip the task if it is already done
            preds = json.load(open(os.path.join(str_mask, f'{task}.json')))
            if len(preds) == 50:
                continue

        output = []
        print(f'Doing task {task}')

        for question, answer in info.items():
            print('---New Instance---')
            print(question)
            q_prompt = prompt + f'Question: {question}\n'

            # save the prompting logs
            prompts = []
            # record the number of API calls
            num_calls = 0
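
            # Alternate between Codex completions and NCBI API calls: each URL
            # the model emits is executed and its result appended to the prompt,
            # until the model produces an answer (no URL) or a limit is hit.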
            while True:
                if len(q_prompt) > cut_length:
                    # truncate the prompt from the start
                    q_prompt = q_prompt[len(q_prompt) - cut_length:]

                body = {
                    "model": "code-davinci-002",
                    "prompt": q_prompt,
                    "max_tokens": 512,
                    "temperature": 0,
                    "stop": ['->', '\n\nQuestion'],
                    "n": 1
                }

                # codex has a rate limit of 20 requests / min; keeping calls at
                # least ~3.1 s apart is a simple workaround
                delta = time.time() - prev_call
                if delta < 3.1:
                    time.sleep(3.1 - delta)

                try:
                    prev_call = time.time()
                    response = openai.Completion.create(**body)
                except openai.error.InvalidRequestError:
                    # the prompt is still too long for the model
                    output.append([question, answer, 'lengthError', prompts])
                    break

                text = response['choices'][0]['text']
                print(text)
                num_calls += 1
                prompts.append([q_prompt, text])

                # extract the first URL the model generated, if any
                url_regex = r'\[(https?://[^\[\]]+)\]'
                matches = re.findall(url_regex, text)

                if matches:
                    url = matches[0]

                    # wait until the BLAST search is done on the NCBI server
                    if 'blast' in url and 'Get' in url:
                        time.sleep(30)

                    call = call_api(url)

                    if 'blast' in url and 'Put' in url:
                        # for BLAST Put requests, keep only the returned RID
                        rid = re.search('RID = (.*)\n', call.decode('utf-8')).group(1)
                        call = rid

                    # truncate very long API responses before appending them
                    if len(call) > 10000:
                        call = call[:10000]

                    q_prompt = f'{q_prompt}{text}->[{call}]\n'
                else:
                    # no URL generated, so the text is treated as the final answer
                    output.append([question, answer, text, prompts])
                    break

                # prevent dead loops
                if num_calls >= 10:
                    output.append([question, answer, 'numError', prompts])
                    break

        with open(os.path.join(str_mask, f'{task}.json'), 'w') as f:
            json.dump(output, f, indent=4)