21
21
import datetime
22
22
import os
23
23
import time
24
+ import json
24
25
25
26
import torch
26
27
import torch .nn as nn
27
28
from rouge import Rouge
28
- import dgl
29
- from tools import utils
30
- from tools .logger import *
29
+
30
+ from HiGraph import HSumGraph , HSumDocGraph
31
31
from Tester import SLTester
32
- from module .vocabulary import Vocab
33
- from module .embedding import Word_Embedding
34
32
from module .dataloader import ExampleSet , MultiExampleSet , graph_collate_fn
35
-
36
- from model .HiGraph import HSumGraph , HSumDocGraph
33
+ from module .embedding import Word_Embedding
34
+ from module .vocabulary import Vocab
35
+ from tools import utils
36
+ from tools .logger import *
37
37
38
38
39
- def load_test_model (model , model_name , eval_dir , save_root , gpu ):
39
+ def load_test_model (model , model_name , eval_dir , save_root ):
40
40
""" choose which model will be loaded for evaluation """
41
41
if model_name .startswith ('eval' ):
42
42
bestmodel_load_path = os .path .join (eval_dir , model_name [4 :])
@@ -51,30 +51,16 @@ def load_test_model(model, model_name, eval_dir, save_root, gpu):
51
51
raise ValueError ("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop" )
52
52
if not os .path .exists (bestmodel_load_path ):
53
53
logger .error ("[ERROR] Restoring %s for testing...The path %s does not exist!" , model_name , bestmodel_load_path )
54
- raise ValueError ( "[ERROR] Restoring %s for testing...The path %s does not exist!" % ( model_name , bestmodel_load_path ))
54
+ return None
55
55
logger .info ("[INFO] Restoring %s for testing...The path is %s" , model_name , bestmodel_load_path )
56
56
57
+ model .load_state_dict (torch .load (bestmodel_load_path ))
57
58
58
- if len (gpu ) > 1 :
59
- model .load_state_dict (torch .load (bestmodel_load_path ))
60
- model = model .module
61
- else :
62
- model .load_state_dict (torch .load (bestmodel_load_path ))
63
-
64
- if model == None :
65
- raise ValueError ("No model has been loaded for evaluation!" )
66
59
return model
67
60
68
61
69
62
70
63
def run_test (model , dataset , loader , model_name , hps ):
71
- """ evaluation phrase
72
- :param model: the model
73
- :param dataset: test dataset which includes text and summary
74
- :param loader: test dataset loader
75
- :param hps: hps for model
76
- :param model_name: model name to load
77
- """
78
64
test_dir = os .path .join (hps .save_root , "test" ) # make a subdir of the root dir for eval data
79
65
eval_dir = os .path .join (hps .save_root , "eval" )
80
66
if not os .path .exists (test_dir ) : os .makedirs (test_dir )
@@ -88,7 +74,7 @@ def run_test(model, dataset, loader, model_name, hps):
88
74
resfile = open (log_dir , "w" )
89
75
logger .info ("[INFO] Write the Evaluation into %s" , log_dir )
90
76
91
- model = load_test_model (model , model_name , eval_dir , hps .save_root , hps . gpu )
77
+ model = load_test_model (model , model_name , eval_dir , hps .save_root )
92
78
model .eval ()
93
79
94
80
iter_start_time = time .time ()
@@ -104,7 +90,7 @@ def run_test(model, dataset, loader, model_name, hps):
104
90
running_avg_loss = tester .running_avg_loss
105
91
106
92
if hps .save_label :
107
- import json
93
+ # save label and do not calculate rouge
108
94
json .dump (tester .extractLabel , resfile )
109
95
tester .SaveDecodeFile ()
110
96
logger .info (' | end of test | time: {:5.2f}s | ' .format ((time .time () - iter_start_time )))
@@ -137,7 +123,7 @@ def run_test(model, dataset, loader, model_name, hps):
137
123
138
124
139
125
def main ():
140
- parser = argparse .ArgumentParser (description = 'SumGraph Model' )
126
+ parser = argparse .ArgumentParser (description = 'HeterSumGraph Model' )
141
127
142
128
# Where to find data
143
129
parser .add_argument ('--data_dir' , type = str , default = 'data/CNNDM' , help = 'The dataset directory.' )
@@ -154,24 +140,24 @@ def main():
154
140
parser .add_argument ('--log_root' , type = str , default = 'log/' , help = 'Root directory for all logging.' )
155
141
156
142
# Hyperparameters
157
- parser .add_argument ('--gpu' , type = str , default = '0' , help = 'GPU ID to use. For cpu, set -1 [default: -1] ' )
143
+ parser .add_argument ('--gpu' , type = str , default = '0' , help = 'GPU ID to use' )
158
144
parser .add_argument ('--cuda' , action = 'store_true' , default = False , help = 'use cuda' )
159
- parser .add_argument ('--vocab_size' , type = int , default = 50000 , help = 'Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file. ' )
160
- parser .add_argument ('--batch_size' , type = int , default = 32 , help = 'Mini batch size [default: 128 ]' )
161
- parser .add_argument ('--n_iter' , type = int , default = 1 , help = 'iteration hop ' )
145
+ parser .add_argument ('--vocab_size' , type = int , default = 50000 , help = 'Size of vocabulary.' )
146
+ parser .add_argument ('--batch_size' , type = int , default = 32 , help = 'Mini batch size [default: 32 ]' )
147
+ parser .add_argument ('--n_iter' , type = int , default = 1 , help = 'iteration ' )
162
148
163
149
parser .add_argument ('--word_embedding' , action = 'store_true' , default = True , help = 'whether to use Word embedding' )
164
- parser .add_argument ('--word_emb_dim' , type = int , default = 300 , help = 'Word embedding size [default: 200 ]' )
150
+ parser .add_argument ('--word_emb_dim' , type = int , default = 300 , help = 'Word embedding size [default: 300 ]' )
165
151
parser .add_argument ('--embed_train' , action = 'store_true' , default = False , help = 'whether to train Word embedding [default: False]' )
166
- parser .add_argument ('--feat_embed_size' , type = int , default = 50 , help = 'Word embedding size [default: 50]' )
167
- parser .add_argument ('--n_layers' , type = int , default = 1 , help = 'Number of deeplstm layers' )
152
+ parser .add_argument ('--feat_embed_size' , type = int , default = 50 , help = 'feature embedding size [default: 50]' )
153
+ parser .add_argument ('--n_layers' , type = int , default = 1 , help = 'Number of GAT layers [default: 1] ' )
168
154
parser .add_argument ('--lstm_hidden_state' , type = int , default = 128 , help = 'size of lstm hidden state' )
169
155
parser .add_argument ('--lstm_layers' , type = int , default = 2 , help = 'lstm layers' )
170
156
parser .add_argument ('--bidirectional' , action = 'store_true' , default = True , help = 'use bidirectional LSTM' )
171
157
parser .add_argument ('--n_feature_size' , type = int , default = 128 , help = 'size of node feature' )
172
- parser .add_argument ('--hidden_size' , type = int , default = 64 , help = 'hidden size [default: 512 ]' )
158
+ parser .add_argument ('--hidden_size' , type = int , default = 64 , help = 'hidden size [default: 64 ]' )
173
159
parser .add_argument ('--gcn_hidden_size' , type = int , default = 128 , help = 'hidden size [default: 64]' )
174
- parser .add_argument ('--ffn_inner_hidden_size' , type = int , default = 512 , help = 'PositionwiseFeedForward inner hidden size [default: 2048 ]' )
160
+ parser .add_argument ('--ffn_inner_hidden_size' , type = int , default = 512 , help = 'PositionwiseFeedForward inner hidden size [default: 512 ]' )
175
161
parser .add_argument ('--n_head' , type = int , default = 8 , help = 'multihead attention number [default: 8]' )
176
162
parser .add_argument ('--recurrent_dropout_prob' , type = float , default = 0.1 , help = 'recurrent dropout prob [default: 0.1]' )
177
163
parser .add_argument ('--atten_dropout_prob' , type = float , default = 0.1 ,help = 'attention dropout prob [default: 0.1]' )
@@ -181,7 +167,7 @@ def main():
181
167
parser .add_argument ('--doc_max_timesteps' , type = int , default = 50 , help = 'max length of documents (max timesteps of documents)' )
182
168
parser .add_argument ('--save_label' , action = 'store_true' , default = False , help = 'require multihead attention' )
183
169
parser .add_argument ('--limited' , action = 'store_true' , default = False , help = 'limited hypo length' )
184
- parser .add_argument ('--blocking' , action = 'store_true' , default = False , help = 'limited hypo length ' )
170
+ parser .add_argument ('--blocking' , action = 'store_true' , default = False , help = 'ngram blocking ' )
185
171
186
172
parser .add_argument ('-m' , type = int , default = 3 , help = 'decode summary length' )
187
173
@@ -221,33 +207,24 @@ def main():
221
207
hps = args
222
208
logger .info (hps )
223
209
210
+ test_w2s_path = os .path .join (args .cache_dir , "test.w2s.tfidf.jsonl" )
224
211
if hps .model == "HSG" :
225
212
model = HSumGraph (hps , embed )
226
213
logger .info ("[MODEL] HeterSumGraph " )
227
- train_w2s_path = os .path .join (args .cache_dir , "test.w2s.tfidf.jsonl" )
228
- dataset = ExampleSet (DATA_FILE , vocab , hps .doc_max_timesteps , hps .sent_max_len , FILTER_WORD , train_w2s_path )
214
+ dataset = ExampleSet (DATA_FILE , vocab , hps .doc_max_timesteps , hps .sent_max_len , FILTER_WORD , test_w2s_path )
229
215
loader = torch .utils .data .DataLoader (dataset , batch_size = hps .batch_size , shuffle = True , num_workers = 32 ,collate_fn = graph_collate_fn )
230
216
elif hps .model == "HDSG" :
231
217
model = HSumDocGraph (hps , embed )
232
218
logger .info ("[MODEL] HeterDocSumGraph " )
233
- train_w2s_path = os .path .join (args .cache_dir , "test.w2s.tfidf.jsonl" )
234
- train_w2d_path = os .path .join (args .cache_dir , "test.w2d.tfidf.jsonl" )
235
- dataset = MultiExampleSet (DATA_FILE , vocab , hps .doc_max_timesteps , hps .sent_max_len , FILTER_WORD , train_w2s_path , train_w2d_path )
219
+ test_w2d_path = os .path .join (args .cache_dir , "test.w2d.tfidf.jsonl" )
220
+ dataset = MultiExampleSet (DATA_FILE , vocab , hps .doc_max_timesteps , hps .sent_max_len , FILTER_WORD , test_w2s_path , test_w2d_path )
236
221
loader = torch .utils .data .DataLoader (dataset , batch_size = hps .batch_size , shuffle = True , num_workers = 32 ,collate_fn = graph_collate_fn )
237
222
else :
238
223
logger .error ("[ERROR] Invalid Model Type!" )
239
224
raise NotImplementedError ("Model Type has not been implemented" )
240
225
241
226
if args .cuda :
242
- model = model .cuda ()
243
-
244
- if len (args .gpu ) > 1 :
245
- gpuid = args .gpu .split (',' )
246
- gpuid = [int (s ) for s in gpuid ]
247
- model = nn .DataParallel (model ,device_ids = gpuid )
248
- logger .info ("[INFO] Use Multi-gpu: %s" , args .gpu )
249
- if hps .cuda :
250
- model = model .cuda ()
227
+ model .to (torch .device ("cuda:0" ))
251
228
logger .info ("[INFO] Use cuda" )
252
229
253
230
logger .info ("[INFO] Decoding..." )
0 commit comments