-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathfind_knn.py
73 lines (59 loc) · 2.59 KB
/
find_knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import hydra.utils as hu
import hydra
from hydra.core.hydra_config import HydraConfig
import torch
import tqdm
from torch.utils.data import DataLoader
from src.data.collators import DataCollatorWithPaddingAndCuda
import faiss
import numpy as np
import json
class KNNFinder:
    """Encode a dataset split and retrieve nearest neighbours from a FAISS index.

    ``find`` writes a JSON list with one record per encoded example. Each record
    is taken from ``dataset_reader.task.dataset`` and extended with a ``ctxs``
    list of ``{"id": int}`` neighbour entries.
    """

    def __init__(self, cfg) -> None:
        """Build the reader, dataloader, encoder and FAISS index from a config.

        Args:
            cfg: Hydra/OmegaConf config. Keys read here: ``cuda_device``,
                ``dataset_reader``, ``instruction``, ``batch_size``, ``model``,
                ``index_path``, ``output_path`` and ``dataset_split``.
        """
        self.cuda_device = cfg.cuda_device
        self.instruction = cfg.instruction
        self.dataset_reader = hu.instantiate(cfg.dataset_reader)
        # Truthiness test, so this branch choice matches the one made in ``forward``.
        if self.instruction:
            # Instruction mode: batches use DataLoader's default collate function.
            self.dataloader = DataLoader(self.dataset_reader, batch_size=cfg.batch_size)
        else:
            # Collator built from the reader's tokenizer and the target device
            # (by its name, it pads batches and moves them to CUDA — confirm).
            collator = DataCollatorWithPaddingAndCuda(
                tokenizer=self.dataset_reader.tokenizer, device=self.cuda_device
            )
            self.dataloader = DataLoader(
                self.dataset_reader, batch_size=cfg.batch_size, collate_fn=collator
            )
        self.model = hu.instantiate(cfg.model).to(self.cuda_device)
        self.index = faiss.read_index(cfg.index_path)
        self.output_path = cfg.output_path
        # On the train split the top hit is dropped in ``search``; presumably the
        # index contains the train queries themselves (TODO confirm).
        self.is_train = cfg.dataset_split == "train"

    def forward(self):
        """Run the encoder over every batch of the dataloader.

        Returns:
            A list of ``{"res": embedding, "metadata": ...}`` dicts, one per
            example, in dataloader order. In non-instruction mode ``res`` is a
            NumPy array copied to the CPU and ``metadata`` is the per-example
            metadata from the batch. In instruction mode ``res`` is whatever the
            model returned for that example, and ``metadata`` is ``{"id": id}``.
        """
        res_list = []
        with torch.no_grad():
            for batch in tqdm.tqdm(self.dataloader):
                embeddings = self.model(**batch)
                if self.instruction:
                    # Default collation gives batched ids (``.tolist()`` is
                    # called on them); convert to plain Python values.
                    ids = batch["metadata"]["id"].tolist()
                    res_list.extend(
                        {"res": emb, "metadata": {"id": idx}}
                        for emb, idx in zip(embeddings, ids)
                    )
                else:
                    embeddings = embeddings.cpu().detach().numpy()
                    res_list.extend(
                        {"res": emb, "metadata": meta}
                        for emb, meta in zip(embeddings, batch["metadata"])
                    )
        return res_list

    def search(self, entry, k=50):
        """Return up to ``k`` nearest-neighbour ids for one encoded example.

        Args:
            entry: A record produced by ``forward``. ``entry["res"]`` is expanded
                to a batch of one query before searching (assumed 1-D embedding).
            k: Maximum number of neighbours to return.

        Returns:
            A list of ``{"id": int}`` dicts in the order returned by the index.
            FAISS pads missing results with label ``-1``. Those placeholders are
            dropped, so the list can be shorter than ``k`` (e.g. on a tiny index).
        """
        query = np.expand_dims(entry["res"], axis=0)
        # One extra hit is requested so the train split can discard its top hit.
        _, labels = self.index.search(query, k + 1)
        near_ids = labels[0]
        if self.is_train:
            near_ids = near_ids[1:]
        return [{"id": int(a)} for a in near_ids[:k] if a >= 0]

    def find(self, k=50):
        """Encode the split, attach ``k`` neighbours to every example, write JSON.

        Args:
            k: Neighbours retrieved per example (defaults to 50, as before).
        """
        res_list = self.forward()
        dataset = self.dataset_reader.task.dataset
        data_list = []
        for entry in tqdm.tqdm(res_list):
            data = dataset[entry["metadata"]["id"]]
            data["ctxs"] = self.search(entry, k=k)
            data_list.append(data)
        with open(self.output_path, "w", encoding="utf-8") as f:
            json.dump(data_list, f)
# Example invocation:
#   python find_knn.py index_path=$PWD/data/break_mpnet_q.bin output_path=$PWD/data/break_mpnet_q_prompts.json cuda_device=1 setup_type=q dataset_split=validation task_name=break
@hydra.main(config_path="configs", config_name="knn_finder")
def main(cfg):
    """Hydra entry point.

    ``cfg`` is composed from ``configs/knn_finder`` plus command-line overrides.
    The composed config is printed, then the neighbour search is run with it.
    """
    print(cfg)
    KNNFinder(cfg).find()


if __name__ == "__main__":
    # No argument is passed here: the ``hydra.main`` decorator supplies ``cfg``.
    main()