-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·110 lines (91 loc) · 5.18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Main file for 11-747 Project. By Alex Coda, Andrew Runge, & Liz Salesky."""
import argparse
import importlib
import logging
import logging.config
import pickle
import random

import torch

#local imports
from encdec import RNNEncoder, RNNDecoder, EncDec, AttnDecoder, CondGruDecoder
from preprocessing import input_reader, create_vocab
from utils import use_cuda
from training import MTTrainer
from train_monitor import TrainMonitor
def main(args):
    """Train (or resume training of) an encoder-decoder MT model.

    Loads hyperparameters from the config module named by ``args.config``,
    builds or loads the vocabularies and model, then hands everything to
    ``MTTrainer``.

    Args:
        args: argparse.Namespace with attributes ``config``, ``debug``,
            ``model``, ``srcvocab``, ``tgtvocab``, ``maxnumsents``
            (see the ``__main__`` parser).
    """
    print(args.config)
    # Strip only a *trailing* '.py'.  The old str.replace('.py', '') removed
    # every occurrence, so a config named e.g. 'my.python_params.py' was
    # mangled before import.
    config_module = args.config[:-3] if args.config.endswith('.py') else args.config
    params = importlib.import_module(config_module)
    logging.config.fileConfig('config/logging.conf', disable_existing_loggers=False,
                              defaults={'filename': '{}/training.log'.format(params.OUTPUT_PATH)})
    logger = logging.getLogger(__name__)
    logger.info("Use CUDA: {}".format(use_cuda))  #set automatically in utils

    if args.debug:
        # Point every split at the tiny debug corpus for fast iteration.
        params.train_src = 'data/examples/debug.en'
        params.train_tgt = 'data/examples/debug.cs'
        params.dev_src = 'data/examples/debug.en'
        params.dev_tgt = 'data/examples/debug.cs'
        params.tst_src = 'data/examples/debug.en'
        params.tst_tgt = 'data/examples/debug.cs'

    if params.fixed_seeds:
        # Seed all RNG sources so runs are reproducible.
        torch.manual_seed(69)
        if use_cuda:
            torch.cuda.manual_seed(69)
        random.seed(69)

    max_num_sents = int(args.maxnumsents)

    # Read in or create vocabs.  NOTE(review): pickle.load is only safe on
    # trusted, locally produced vocab files — never on untrusted input.
    if args.srcvocab is not None:
        with open(args.srcvocab, 'rb') as f:  # 'with' closes the handle (old code leaked it)
            src_vocab = pickle.load(f)
    else:
        src_vocab = create_vocab(params.train_src, params.src_lang, max_num_sents,
                                 params.max_sent_length, max_vocab_size=100000)
        src_vocab.save(params.src_vocab + ".pkl")
    if args.tgtvocab is not None:
        with open(args.tgtvocab, 'rb') as f:
            tgt_vocab = pickle.load(f)
    else:
        tgt_vocab = create_vocab(params.train_tgt, params.tgt_lang, max_num_sents,
                                 params.max_sent_length, max_vocab_size=100000)
        tgt_vocab.save(params.tgt_vocab + ".pkl")

    input_size = src_vocab.vocab_size()
    output_size = tgt_vocab.vocab_size()
    logger.info("src vocab size: {}".format(input_size))
    logger.info("tgt vocab size: {}".format(output_size))

    # Read in data.  Training data is sorted (presumably for length-batching —
    # see input_reader); dev is loaded both sorted and unsorted; dev/test pass
    # filt=False so no sentences are filtered out of evaluation.
    train_sents = input_reader(params.train_src, params.train_tgt, params.src_lang,
                               params.tgt_lang, max_num_sents, params.max_sent_length,
                               src_vocab, tgt_vocab, sort=True)
    dev_sents_unsorted = input_reader(params.dev_src, params.dev_tgt, params.src_lang,
                                      params.tgt_lang, max_num_sents, params.max_sent_length,
                                      src_vocab, tgt_vocab, filt=False)
    dev_sents_sorted = input_reader(params.dev_src, params.dev_tgt, params.src_lang,
                                    params.tgt_lang, max_num_sents, params.max_sent_length,
                                    src_vocab, tgt_vocab, sort=True, filt=False)
    tst_sents = input_reader(params.tst_src, params.tst_tgt, params.src_lang,
                             params.tgt_lang, max_num_sents, params.max_sent_length,
                             src_vocab, tgt_vocab, filt=False)

    # Initialize the model: resume from a checkpoint if one was given,
    # otherwise build a fresh encoder/decoder pair from the config.
    if args.model is not None:
        model = torch.load(args.model)
    else:
        enc = RNNEncoder(vocab_size=input_size, embed_size=params.embed_size,
                         hidden_size=params.enc_hidden_size, rnn_type='GRU',
                         num_layers=1, bidirectional=params.bi_enc)
        if params.cond_gru_dec:
            dec = CondGruDecoder(enc_size=params.enc_hidden_size, vocab_size=output_size,
                                 embed_size=params.embed_size, hidden_size=params.dec_hidden_size,
                                 bidirectional_enc=params.bi_enc)
        else:
            dec = AttnDecoder(enc_size=params.enc_hidden_size, vocab_size=output_size,
                              embed_size=params.embed_size, hidden_size=params.dec_hidden_size,
                              rnn_type='GRU', num_layers=1, bidirectional_enc=params.bi_enc,
                              tgt_vocab=tgt_vocab)
        model = EncDec(enc, dec)
    if use_cuda:
        model = model.cuda()

    monitor = TrainMonitor(model, len(train_sents), print_every=params.print_every,
                           plot_every=params.plot_every, save_plot_every=params.plot_every,
                           model_every=params.model_every, checkpoint_every=params.checkpoint_every,
                           patience=params.patience, num_epochs=params.num_epochs,
                           output_path=params.OUTPUT_PATH, model_path=params.MODEL_PATH)
    trainer = MTTrainer(model, monitor, optim_type='Adam', batch_size=params.batch_size,
                        beam_size=params.beam_size, learning_rate=0.0001)
    trainer.train(train_sents, dev_sents_sorted, dev_sents_unsorted, tst_sents,
                  src_vocab, tgt_vocab, params.num_epochs,
                  max_gen_length=params.max_gen_length, debug=args.debug,
                  output_path=params.OUTPUT_PATH)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--debug", action='store_true',
                        help="train on the tiny debug corpus")
    parser.add_argument("-m", "--model", default=None,
                        help="path to a saved model checkpoint to resume from")
    parser.add_argument("-s", "--srcvocab", default=None,
                        help="path to a pickled source vocabulary")
    parser.add_argument("-t", "--tgtvocab", default=None,
                        help="path to a pickled target vocabulary")
    # type=int so CLI values and the default are the same type; the old code
    # yielded a str from the command line but an int when defaulted.
    parser.add_argument("-n", "--maxnumsents", type=int, default=250000,
                        help="max sentences to read; default is high enough for all")
    parser.add_argument("-c", "--config", default="params.py",
                        help="config module holding hyperparameters")
    args = parser.parse_args()
    main(args)