-
Notifications
You must be signed in to change notification settings - Fork 0
/
merge_model_gen_data.py
56 lines (44 loc) · 1.85 KB
/
merge_model_gen_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import json
import sys
from loguru import logger
import time
class Dataset:
    """One model's generation outputs, stored as a JSON Lines file.

    Each non-blank line of the file is kept as a raw string in ``self.data``
    and decoded with ``json.loads`` on every ``dataset[index]`` access.
    """

    def __init__(self, model, path):
        """Load the raw lines of *path* and remember the producing *model*.

        Args:
            model: Label identifying the model that produced the file; stored
                unchanged in ``self.model``.
            path: Path to a UTF-8 JSON Lines file (one JSON value per line).
        """
        self.model = model
        with open(path, 'r', encoding='utf-8') as f:
            # Iterate the text stream instead of calling str.splitlines():
            # JSON strings may legally contain raw U+2028/U+2029, which
            # splitlines() would treat as line breaks.  Blank lines (an empty
            # file, a trailing newline, stray empty lines) are skipped so they
            # never turn into undecodable "" records.
            self.data = [line.rstrip('\n') for line in f if line.strip()]

    def __getitem__(self, index):
        """Return record *index* decoded from JSON (a fresh object each call)."""
        return json.loads(self.data[index])

    def __len__(self):
        """Return the number of records (non-blank lines) in the file."""
        return len(self.data)
# Number of leading characters compared when checking that every model's row
# refers to the same article as the first model's row.
ARTICLE_PREFIX_LEN = 20
# Where the merged JSON is written when no output path is given.
DEFAULT_OUTPUT_PATH = './merge_model_output.json'


def _read_manifest(manifest_path):
    """Parse the manifest into a list of ``(model, data_path)`` tuples.

    Each non-blank line has the form ``<model> <data_path>``. Only the first
    run of whitespace separates the two fields, so the path may contain spaces.

    Raises:
        ValueError: if a non-blank line does not contain two fields.
    """
    path_and_models = []
    with open(manifest_path, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank lines and a trailing newline
            fields = line.split(maxsplit=1)
            if len(fields) != 2:
                raise ValueError(
                    f"{manifest_path}:{line_no}: expected '<model> <path>', got {line!r}"
                )
            path_and_models.append((fields[0], fields[1]))
    return path_and_models


def _merge_datasets(datasets):
    """Merge equally long per-model datasets row by row into output records.

    For row ``i``, each record holds the article and gold ``label_questions``
    taken from the first dataset, followed by every dataset's ``ga`` entry,
    in dataset order.

    Raises:
        ValueError: if dataset lengths differ, or a row's article prefix does
            not match the first dataset's article for that row.
        KeyError: if a record lacks 'article', 'label_questions' (first
            dataset) or 'ga'.
    """
    data_len = len(datasets[0])
    for dataset in datasets:
        if len(dataset) != data_len:
            raise ValueError(
                f"dataset length mismatch for model {dataset.model!r}: "
                f"{len(dataset)}, {data_len}"
            )

    models = [d.model for d in datasets]
    merge_outputs = []
    for index in range(data_len):
        rows = [d[index] for d in datasets]  # decode each JSON line only once
        article = rows[0]['article']
        question_groups = [rows[0]['label_questions']]
        for model, row in zip(models, rows):
            # Cheap sanity check that the files are aligned row by row.
            if article[:ARTICLE_PREFIX_LEN] != row['article'][:ARTICLE_PREFIX_LEN]:
                raise ValueError(
                    f"article is not equal at index {index} for model {model!r}"
                )
            question_groups.append(row['ga'])  # use only the 'ga' output
        merge_outputs.append({
            '_id': f"{index}",
            '_models': list(models),
            'article': article,
            'questionGroups': question_groups,
        })
    return merge_outputs


def main(argv=None):
    """Merge the per-model outputs listed in a manifest into one JSON file.

    Usage: ``merge_model_gen_data.py MANIFEST [OUTPUT]``

    MANIFEST lists one ``<model> <data_path>`` pair per line. OUTPUT defaults
    to ``./merge_model_output.json``.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``.

    Raises:
        SystemExit: if no manifest is given or it contains no entries.
        ValueError: on a malformed manifest line or misaligned datasets.
    """
    argv = sys.argv[1:] if argv is None else argv
    if not argv:
        raise SystemExit(f"usage: {sys.argv[0]} MANIFEST [OUTPUT]")
    manifest_path = argv[0]
    output_path = argv[1] if len(argv) > 1 else DEFAULT_OUTPUT_PATH

    path_and_models = _read_manifest(manifest_path)
    if not path_and_models:
        raise SystemExit(f"{manifest_path}: no '<model> <path>' entries found")

    datasets = [Dataset(model, path) for model, path in path_and_models]
    merge_outputs = _merge_datasets(datasets)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(merge_outputs, f)


if __name__ == "__main__":
    main()