
Commit 19a73ec: repair comments and move HiGraph outside
Parent: 56a53a1

4 files changed: +43 -71 lines


model/HiGraph.py → HiGraph.py (file renamed without changes)

Moving HiGraph.py out of model/ is why the import lines in evaluation.py and train.py below change from "from model.HiGraph import ..." to "from HiGraph import ...".

evaluation.py (+28 -51)

@@ -21,22 +21,22 @@
 import datetime
 import os
 import time
+import json
 
 import torch
 import torch.nn as nn
 from rouge import Rouge
-import dgl
-from tools import utils
-from tools.logger import *
+
+from HiGraph import HSumGraph, HSumDocGraph
 from Tester import SLTester
-from module.vocabulary import Vocab
-from module.embedding import Word_Embedding
 from module.dataloader import ExampleSet, MultiExampleSet, graph_collate_fn
-
-from model.HiGraph import HSumGraph, HSumDocGraph
+from module.embedding import Word_Embedding
+from module.vocabulary import Vocab
+from tools import utils
+from tools.logger import *
 
 
-def load_test_model(model, model_name, eval_dir, save_root, gpu):
+def load_test_model(model, model_name, eval_dir, save_root):
     """ choose which model will be loaded for evaluation """
     if model_name.startswith('eval'):
         bestmodel_load_path = os.path.join(eval_dir, model_name[4:])
@@ -51,30 +51,16 @@ def load_test_model(model, model_name, eval_dir, save_root, gpu):
         raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
     if not os.path.exists(bestmodel_load_path):
         logger.error("[ERROR] Restoring %s for testing...The path %s does not exist!", model_name, bestmodel_load_path)
-        raise ValueError("[ERROR] Restoring %s for testing...The path %s does not exist!" % (model_name, bestmodel_load_path))
+        return None
     logger.info("[INFO] Restoring %s for testing...The path is %s", model_name, bestmodel_load_path)
 
+    model.load_state_dict(torch.load(bestmodel_load_path))
 
-    if len(gpu) > 1:
-        model.load_state_dict(torch.load(bestmodel_load_path))
-        model = model.module
-    else:
-        model.load_state_dict(torch.load(bestmodel_load_path))
-
-    if model == None:
-        raise ValueError("No model has been loaded for evaluation!")
     return model
 
 
 
 def run_test(model, dataset, loader, model_name, hps):
-    """ evaluation phrase
-    :param model: the model
-    :param dataset: test dataset which includes text and summary
-    :param loader: test dataset loader
-    :param hps: hps for model
-    :param model_name: model name to load
-    """
     test_dir = os.path.join(hps.save_root, "test")  # make a subdir of the root dir for eval data
     eval_dir = os.path.join(hps.save_root, "eval")
     if not os.path.exists(test_dir) : os.makedirs(test_dir)
@@ -88,7 +74,7 @@ def run_test(model, dataset, loader, model_name, hps):
     resfile = open(log_dir, "w")
     logger.info("[INFO] Write the Evaluation into %s", log_dir)
 
-    model = load_test_model(model, model_name, eval_dir, hps.save_root, hps.gpu)
+    model = load_test_model(model, model_name, eval_dir, hps.save_root)
     model.eval()
 
     iter_start_time=time.time()
@@ -104,7 +90,7 @@ def run_test(model, dataset, loader, model_name, hps):
     running_avg_loss = tester.running_avg_loss
 
     if hps.save_label:
-        import json
+        # save label and do not calculate rouge
         json.dump(tester.extractLabel, resfile)
         tester.SaveDecodeFile()
     logger.info(' | end of test | time: {:5.2f}s | '.format((time.time() - iter_start_time)))
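Note on the load_test_model change: a missing checkpoint now returns None instead of raising ValueError, so the caller becomes responsible for checking the return value before using it. A minimal sketch of that contract with an explicit guard (the helper name load_checkpoint and the guard itself are illustrative assumptions, not part of this commit):

import os

import torch
import torch.nn as nn


def load_checkpoint(model, path):
    """Return the model with restored weights, or None when the file is missing."""
    if not os.path.exists(path):
        return None
    # same single restore path the commit keeps after dropping the multi-GPU branch
    model.load_state_dict(torch.load(path))
    return model


model = nn.Linear(4, 2)  # stand-in for HSumGraph / HSumDocGraph
loaded = load_checkpoint(model, "save/eval/bestmodel_0")
if loaded is None:
    raise SystemExit("no checkpoint found; nothing to evaluate")
loaded.eval()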
@@ -137,7 +123,7 @@ def run_test(model, dataset, loader, model_name, hps):
 
 
 def main():
-    parser = argparse.ArgumentParser(description='SumGraph Model')
+    parser = argparse.ArgumentParser(description='HeterSumGraph Model')
 
     # Where to find data
     parser.add_argument('--data_dir', type=str, default='data/CNNDM', help='The dataset directory.')
@@ -154,24 +140,24 @@ def main():
     parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.')
 
     # Hyperparameters
-    parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. For cpu, set -1 [default: -1]')
+    parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use')
     parser.add_argument('--cuda', action='store_true', default=False, help='use cuda')
-    parser.add_argument('--vocab_size', type=int, default=50000, help='Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.')
-    parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 128]')
-    parser.add_argument('--n_iter', type=int, default=1, help='iteration hop')
+    parser.add_argument('--vocab_size', type=int, default=50000, help='Size of vocabulary.')
+    parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 32]')
+    parser.add_argument('--n_iter', type=int, default=1, help='iteration ')
 
     parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding')
-    parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 200]')
+    parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 300]')
     parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]')
-    parser.add_argument('--feat_embed_size', type=int, default=50, help='Word embedding size [default: 50]')
-    parser.add_argument('--n_layers', type=int, default=1, help='Number of deeplstm layers')
+    parser.add_argument('--feat_embed_size', type=int, default=50, help='feature embedding size [default: 50]')
+    parser.add_argument('--n_layers', type=int, default=1, help='Number of GAT layers [default: 1]')
     parser.add_argument('--lstm_hidden_state', type=int, default=128, help='size of lstm hidden state')
     parser.add_argument('--lstm_layers', type=int, default=2, help='lstm layers')
     parser.add_argument('--bidirectional', action='store_true', default=True, help='use bidirectional LSTM')
     parser.add_argument('--n_feature_size', type=int, default=128, help='size of node feature')
-    parser.add_argument('--hidden_size', type=int, default=64, help='hidden size [default: 512]')
+    parser.add_argument('--hidden_size', type=int, default=64, help='hidden size [default: 64]')
     parser.add_argument('--gcn_hidden_size', type=int, default=128, help='hidden size [default: 64]')
-    parser.add_argument('--ffn_inner_hidden_size', type=int, default=512, help='PositionwiseFeedForward inner hidden size [default: 2048]')
+    parser.add_argument('--ffn_inner_hidden_size', type=int, default=512, help='PositionwiseFeedForward inner hidden size [default: 512]')
     parser.add_argument('--n_head', type=int, default=8, help='multihead attention number [default: 8]')
     parser.add_argument('--recurrent_dropout_prob', type=float, default=0.1, help='recurrent dropout prob [default: 0.1]')
     parser.add_argument('--atten_dropout_prob', type=float, default=0.1,help='attention dropout prob [default: 0.1]')
@@ -181,7 +167,7 @@ def main():
     parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)')
     parser.add_argument('--save_label', action='store_true', default=False, help='require multihead attention')
     parser.add_argument('--limited', action='store_true', default=False, help='limited hypo length')
-    parser.add_argument('--blocking', action='store_true', default=False, help='limited hypo length')
+    parser.add_argument('--blocking', action='store_true', default=False, help='ngram blocking')
 
     parser.add_argument('-m', type=int, default=3, help='decode summary length')
 
@@ -221,33 +207,24 @@ def main():
     hps = args
     logger.info(hps)
 
+    test_w2s_path = os.path.join(args.cache_dir, "test.w2s.tfidf.jsonl")
     if hps.model == "HSG":
         model = HSumGraph(hps, embed)
         logger.info("[MODEL] HeterSumGraph ")
-        train_w2s_path = os.path.join(args.cache_dir, "test.w2s.tfidf.jsonl")
-        dataset = ExampleSet(DATA_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, train_w2s_path)
+        dataset = ExampleSet(DATA_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, test_w2s_path)
         loader = torch.utils.data.DataLoader(dataset, batch_size=hps.batch_size, shuffle=True, num_workers=32,collate_fn=graph_collate_fn)
     elif hps.model == "HDSG":
         model = HSumDocGraph(hps, embed)
         logger.info("[MODEL] HeterDocSumGraph ")
-        train_w2s_path = os.path.join(args.cache_dir, "test.w2s.tfidf.jsonl")
-        train_w2d_path = os.path.join(args.cache_dir, "test.w2d.tfidf.jsonl")
-        dataset = MultiExampleSet(DATA_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, train_w2s_path, train_w2d_path)
+        test_w2d_path = os.path.join(args.cache_dir, "test.w2d.tfidf.jsonl")
+        dataset = MultiExampleSet(DATA_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, test_w2s_path, test_w2d_path)
         loader = torch.utils.data.DataLoader(dataset, batch_size=hps.batch_size, shuffle=True, num_workers=32,collate_fn=graph_collate_fn)
     else:
         logger.error("[ERROR] Invalid Model Type!")
         raise NotImplementedError("Model Type has not been implemented")
 
     if args.cuda:
-        model = model.cuda()
-
-        if len(args.gpu) > 1:
-            gpuid = args.gpu.split(',')
-            gpuid = [int(s) for s in gpuid]
-            model = nn.DataParallel(model,device_ids=gpuid)
-            logger.info("[INFO] Use Multi-gpu: %s", args.gpu)
-    if hps.cuda:
-        model = model.cuda()
+        model.to(torch.device("cuda:0"))
         logger.info("[INFO] Use cuda")
 
     logger.info("[INFO] Decoding...")
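The device-placement hunk above replaces both the nn.DataParallel multi-GPU branch and the duplicated model.cuda() calls with one explicit placement. A small sketch of the two styles side by side (the CPU fallback is an illustrative assumption; the commit itself unconditionally targets cuda:0):

import torch
import torch.nn as nn

model = nn.Linear(8, 3)  # stand-in for the summarization model

# New style: pin the model to a single device, as in model.to(torch.device("cuda:0")).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Removed style: wrap the model so each batch is split across several GPUs.
# model = nn.DataParallel(model, device_ids=[0, 1])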

tools/utils.py (+1 -1)

@@ -50,7 +50,7 @@ def eval_label(match_true, pred, true, total, match):
         recall = match_true / true
         F = 2 * precision * recall / (precision + recall)
     except ZeroDivisionError:
-        F = 0.0
+        accu, precision, recall, F = 0.0, 0.0, 0.0, 0.0
         logger.error("[Error] float division by zero")
     return accu, precision, recall, F
 
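The previous except branch reset only F, so when one of the divisions raised, the names computed after the failing line were left unbound and the return statement itself would raise NameError. A self-contained sketch of the repaired pattern (the accu and precision formulas sit outside the shown hunk and are assumptions):

import logging

logger = logging.getLogger(__name__)


def eval_label(match_true, pred, true, total, match):
    try:
        accu = match / total           # assumed; outside the shown hunk
        precision = match_true / pred  # assumed; outside the shown hunk
        recall = match_true / true
        F = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        # zero every returned metric so no name is left unbound
        accu, precision, recall, F = 0.0, 0.0, 0.0, 0.0
        logger.error("[Error] float division by zero")
    return accu, precision, recall, F


print(eval_label(0, 0, 0, 0, 0))  # (0.0, 0.0, 0.0, 0.0) rather than a NameError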

train.py (+14 -19)

@@ -24,20 +24,19 @@
 import shutil
 import time
 
-import torch
+import dgl
 import numpy as np
+import torch
 from rouge import Rouge
 
-import dgl
-
-from tools.logger import *
+from HiGraph import HSumGraph, HSumDocGraph
 from Tester import SLTester
-from module.dataloader import LoadHiExampleSet
 from module.dataloader import ExampleSet, MultiExampleSet, graph_collate_fn
 from module.embedding import Word_Embedding
 from module.vocabulary import Vocab
+from tools.logger import *
 
-from model.HiGraph import HSumGraph, HSumDocGraph
+_DEBUG_FLAG_ = False
 
 
 def save_model(model, save_file):
@@ -88,8 +87,6 @@ def run_training(model, train_loader, valid_loader, valset, hps, train_dir):
     '''
     logger.info("[INFO] Starting run_training")
 
-    # optimizer = torch.optim.Adamax(filter(lambda p: p.requires_grad, model.parameters()), lr=hps.lr, betas=(0.9, 0.98),
-    #                                eps=1e-09)
     optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=hps.lr)
 
 
@@ -141,10 +138,11 @@ def run_training(model, train_loader, valid_loader, valset, hps, train_dir):
             epoch_loss += float(loss.data)
 
             if i % 100 == 0:
-                for name, param in model.named_parameters():
-                    if param.requires_grad:
-                        logger.debug(name)
-                        logger.debug(param.grad.data.sum())
+                if _DEBUG_FLAG_:
+                    for name, param in model.named_parameters():
+                        if param.requires_grad:
+                            logger.debug(name)
+                            logger.debug(param.grad.data.sum())
                 logger.info(' | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | '
                             .format(i, (time.time() - iter_start_time),float(train_loss / 100)))
                 train_loss = 0.0
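Gating the gradient dump behind _DEBUG_FLAG_ means a normal run no longer sweeps model.named_parameters() and sums every gradient tensor each time the 100-iteration logging branch fires; the debug output returns by flipping one module-level constant. A self-contained sketch of the pattern:

import logging

import torch
import torch.nn as nn

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

_DEBUG_FLAG_ = False  # flip to True to log per-parameter gradient sums

model = nn.Linear(4, 2)  # stand-in for the real model
loss = model(torch.randn(3, 4)).sum()
loss.backward()

if _DEBUG_FLAG_:
    for name, param in model.named_parameters():
        if param.requires_grad:
            logger.debug(name)
            logger.debug(param.grad.data.sum())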
@@ -226,8 +224,6 @@ def run_eval(model, loader, valset, hps, best_loss, best_F, non_descent_cnt, saveNo):
     tester.getMetric()
     F = tester.labelMetric
 
-    # If running_avg_loss is best so far, save this checkpoint (early stopping).
-    # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
     if best_loss is None or running_avg_loss < best_loss:
         bestmodel_save_path = os.path.join(eval_dir, 'bestmodel_%d' % (saveNo % 3))  # this is where checkpoints of best models are saved
         if best_loss is not None:
@@ -288,7 +284,7 @@ def main():
     parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding [default: True]')
     parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 300]')
     parser.add_argument('--embed_train', action='store_true', default=False,help='whether to train Word embedding [default: False]')
-    parser.add_argument('--feat_embed_size', type=int, default=50, help='Word embedding size [default: 50]')
+    parser.add_argument('--feat_embed_size', type=int, default=50, help='feature embedding size [default: 50]')
     parser.add_argument('--n_layers', type=int, default=1, help='Number of GAT layers [default: 1]')
     parser.add_argument('--lstm_hidden_state', type=int, default=128, help='size of lstm hidden state [default: 128]')
     parser.add_argument('--lstm_layers', type=int, default=2, help='Number of lstm layers [default: 2]')
@@ -348,25 +344,24 @@ def main():
     hps = args
     logger.info(hps)
 
+    train_w2s_path = os.path.join(args.cache_dir, "train.w2s.tfidf.jsonl")
+    val_w2s_path = os.path.join(args.cache_dir, "val.w2s.tfidf.jsonl")
+
     if hps.model == "HSG":
         model = HSumGraph(hps, embed)
         logger.info("[MODEL] HeterSumGraph ")
-        train_w2s_path = os.path.join(args.cache_dir, "train.w2s.tfidf.jsonl")
         dataset = ExampleSet(DATA_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, train_w2s_path)
         train_loader = torch.utils.data.DataLoader(dataset, batch_size=hps.batch_size, shuffle=True, num_workers=32,collate_fn=graph_collate_fn)
         del dataset
-        val_w2s_path = os.path.join(args.cache_dir, "val.w2s.tfidf.jsonl")
         valid_dataset = ExampleSet(VALID_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, val_w2s_path)
         valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=hps.batch_size, shuffle=False, collate_fn=graph_collate_fn, num_workers=32)
     elif hps.model == "HDSG":
         model = HSumDocGraph(hps, embed)
         logger.info("[MODEL] HeterDocSumGraph ")
-        train_w2s_path = os.path.join(args.cache_dir, "train.w2s.tfidf.jsonl")
         train_w2d_path = os.path.join(args.cache_dir, "train.w2d.tfidf.jsonl")
         dataset = MultiExampleSet(DATA_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, train_w2s_path, train_w2d_path)
         train_loader = torch.utils.data.DataLoader(dataset, batch_size=hps.batch_size, shuffle=True, num_workers=32,collate_fn=graph_collate_fn)
         del dataset
-        val_w2s_path = os.path.join(args.cache_dir, "val.w2s.tfidf.jsonl")
         val_w2d_path = os.path.join(args.cache_dir, "val.w2d.tfidf.jsonl")
         valid_dataset = MultiExampleSet(VALID_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, val_w2s_path, val_w2d_path)
         valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=hps.batch_size, shuffle=False,collate_fn=graph_collate_fn, num_workers=32)  # Shuffle Must be False for ROUGE evaluation
