Skip to content

Commit ca6585c

Browse files
committed
author net predict
1 parent 2865750 commit ca6585c

File tree

1 file changed

+43
-6
lines changed

1 file changed

+43
-6
lines changed

predict.py

+43-6
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
parser.add_argument('-vocab',default='data_sample/vocab/Vid2Name.json',type=str,help="entity dictionary")
2222
parser.add_argument('-A',default=False,type=bool,help="train with author info")
2323
parser.add_argument('-E',default=False,type=bool,help="evaluate with real body")
24-
parser.add_argument('-abstract',default='data_sample/sample/abs0.csv',type=str,help="abstract to predict")
25-
parser.add_argument('-body',default='data_sample/sample/body0.csv',type=str,help="body to evaluate")
24+
parser.add_argument('-abstract',default='data_sample/sample_author/abs0.csv',type=str,help="abstract to predict")
25+
parser.add_argument('-authors',default='data_sample/sample_author/authors0.json',type=str,help="author information")
26+
parser.add_argument('-body',default='data_sample/sample_author/body0.csv',type=str,help="body to evaluate")
2627
parser.add_argument('-path',default='model.h5',type=str,help="path for loading the model")
2728
parser.add_argument('-threshold',default=0.0,type=float,help="threshold of prediction, 0.0~1.0")
2829
opt=parser.parse_args()
@@ -51,8 +52,39 @@ def get_prediction(model,abs_path,in_dict):
5152
return abs_vec,list(pred)
5253

5354

54-
def get_prediction_author(model,abs_path,in_dict,author_dict):
55-
pass
55+
def get_prediction_author(model,abs_path,authors_path,in_dict,author_dict):
56+
cc2vid_input=in_dict
57+
fa2vid=author_dict
58+
59+
abs_vec=[0.0 for i in range(0,len(cc2vid_input))]
60+
abs_count=0.0
61+
62+
with open(abs_path,'r',encoding='utf-8') as cf:
63+
rd=csv.reader(cf)
64+
for item in rd:
65+
if item[0]=="Mention":
66+
continue
67+
try:
68+
abs_vec[cc2vid_input[item[1]]]+=1.0
69+
abs_count+=1.0
70+
except:
71+
pass
72+
if not abs_count==0.0:
73+
abs_vec=list(np.array(abs_vec)/abs_count)
74+
75+
author_vec=[0.0 for i in range(0,len(fa2vid))]
76+
authors=util.load_sups(authors_path)
77+
for author in authors:
78+
try:
79+
author_vec[fa2vid[author]]=1.0
80+
except:
81+
pass
82+
83+
sample_input=np.array([abs_vec+author_vec])
84+
pred=model.model.predict([sample_input[:,:len(cc2vid_input)],sample_input[:,len(cc2vid_input):]])[0]
85+
pred/=np.linalg.norm(pred)
86+
87+
return abs_vec,list(pred)
5688

5789
def print_vec(prediction,entity_dict,threshold=0.0):
5890
for i,v in enumerate(prediction):
@@ -120,11 +152,16 @@ def main():
120152

121153
if not author:
122154
bownn_model=BOWNN()
155+
bownn_model.load_model(load_path)
156+
abs_vec,prediction=get_prediction(bownn_model,abs_path,in_dict)
123157
else:
158+
author_path=opt.author
159+
author_dict=util.load_sups(author_path)
160+
authors_path=opt.authors
124161
bownn_model=BOWNN_author()
162+
bownn_model.load_model(load_path)
163+
abs_vec,prediction=get_prediction_author(bownn_model,abs_path,authors_path,in_dict,author_dict)
125164

126-
bownn_model.load_model(load_path)
127-
abs_vec,prediction=get_prediction(bownn_model,abs_path,in_dict)
128165
print("Entity mentions in this abstract:")
129166
print_vec(abs_vec,entity_dict)
130167
print("\n")

0 commit comments

Comments
 (0)