|
21 | 21 | parser.add_argument('-vocab',default='data_sample/vocab/Vid2Name.json',type=str,help="entity dictionary")
|
22 | 22 | parser.add_argument('-A',default=False,type=bool,help="train with author info")
|
23 | 23 | parser.add_argument('-E',default=False,type=bool,help="evaluate with real body")
|
24 |
| -parser.add_argument('-abstract',default='data_sample/sample/abs0.csv',type=str,help="abstract to predict") |
25 |
| -parser.add_argument('-body',default='data_sample/sample/body0.csv',type=str,help="body to evaluate") |
| 24 | +parser.add_argument('-abstract',default='data_sample/sample_author/abs0.csv',type=str,help="abstract to predict") |
| 25 | +parser.add_argument('-authors',default='data_sample/sample_author/authors0.json',type=str,help="author information") |
| 26 | +parser.add_argument('-body',default='data_sample/sample_author/body0.csv',type=str,help="body to evaluate") |
26 | 27 | parser.add_argument('-path',default='model.h5',type=str,help="path for loading the model")
|
27 | 28 | parser.add_argument('-threshold',default=0.0,type=float,help="threshold of prediction, 0.0~1.0")
|
28 | 29 | opt=parser.parse_args()
|
@@ -51,8 +52,39 @@ def get_prediction(model,abs_path,in_dict):
|
51 | 52 | return abs_vec,list(pred)
|
52 | 53 |
|
53 | 54 |
|
54 |
| -def get_prediction_author(model,abs_path,in_dict,author_dict): |
55 |
| - pass |
| 55 | +def get_prediction_author(model,abs_path,authors_path,in_dict,author_dict): |
| 56 | + cc2vid_input=in_dict |
| 57 | + fa2vid=author_dict |
| 58 | + |
| 59 | + abs_vec=[0.0 for i in range(0,len(cc2vid_input))] |
| 60 | + abs_count=0.0 |
| 61 | + |
| 62 | + with open(abs_path,'r',encoding='utf-8') as cf: |
| 63 | + rd=csv.reader(cf) |
| 64 | + for item in rd: |
| 65 | + if item[0]=="Mention": |
| 66 | + continue |
| 67 | + try: |
| 68 | + abs_vec[cc2vid_input[item[1]]]+=1.0 |
| 69 | + abs_count+=1.0 |
| 70 | + except: |
| 71 | + pass |
| 72 | + if not abs_count==0.0: |
| 73 | + abs_vec=list(np.array(abs_vec)/abs_count) |
| 74 | + |
| 75 | + author_vec=[0.0 for i in range(0,len(fa2vid))] |
| 76 | + authors=util.load_sups(authors_path) |
| 77 | + for author in authors: |
| 78 | + try: |
| 79 | + author_vec[fa2vid[author]]=1.0 |
| 80 | + except: |
| 81 | + pass |
| 82 | + |
| 83 | + sample_input=np.array([abs_vec+author_vec]) |
| 84 | + pred=model.model.predict([sample_input[:,:len(cc2vid_input)],sample_input[:,len(cc2vid_input):]])[0] |
| 85 | + pred/=np.linalg.norm(pred) |
| 86 | + |
| 87 | + return abs_vec,list(pred) |
56 | 88 |
|
57 | 89 | def print_vec(prediction,entity_dict,threshold=0.0):
|
58 | 90 | for i,v in enumerate(prediction):
|
@@ -120,11 +152,16 @@ def main():
|
120 | 152 |
|
121 | 153 | if not author:
|
122 | 154 | bownn_model=BOWNN()
|
| 155 | + bownn_model.load_model(load_path) |
| 156 | + abs_vec,prediction=get_prediction(bownn_model,abs_path,in_dict) |
123 | 157 | else:
|
| 158 | + author_path=opt.author |
| 159 | + author_dict=util.load_sups(author_path) |
| 160 | + authors_path=opt.authors |
124 | 161 | bownn_model=BOWNN_author()
|
| 162 | + bownn_model.load_model(load_path) |
| 163 | + abs_vec,prediction=get_prediction_author(bownn_model,abs_path,authors_path,in_dict,author_dict) |
125 | 164 |
|
126 |
| - bownn_model.load_model(load_path) |
127 |
| - abs_vec,prediction=get_prediction(bownn_model,abs_path,in_dict) |
128 | 165 | print("Entity mentions in this abstract:")
|
129 | 166 | print_vec(abs_vec,entity_dict)
|
130 | 167 | print("\n")
|
|
0 commit comments