-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
115 lines (90 loc) · 4.73 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from tqdm import tqdm
import os
import numpy as np
def train(model, model_name, train_iter, val_iter, SRC_TEXT, TRG_TEXT, anneal, num_epochs=20, gpu=False, lr=0.001,
          kl_coef=1.0, min_kl=0.0, word_dpt=0.0, checkpoint=False):
    """Train a variational seq2seq model and periodically evaluate / checkpoint it.

    Args:
        model: the VAE translation model; must expose an ``if_zero`` attribute
            (toggles zeroing of the latent vector at decode time) and return
            ``(re, kl, hidden, mu_prior, log_var_prior, mu_posterior, log_var_posterior)``
            from its forward pass.
        model_name: directory name (under CWD) where eval logs / checkpoints go.
        train_iter, val_iter: torchtext-style batch iterators with ``.src`` / ``.trg``.
        SRC_TEXT, TRG_TEXT: field objects; only ``TRG_TEXT.vocab.stoi['<pad>']`` is used.
        anneal: callable ``anneal(epoch, gpu=...)`` giving the KL annealing weight.
        num_epochs: number of training epochs.
        gpu: forwarded to ``anneal`` and the utils evaluation helpers.
        lr: Adam learning rate.
        kl_coef: multiplier on the annealed KL weight.
        min_kl: free-bits style floor applied to the per-word KL before weighting.
        word_dpt: word-dropout probability forwarded to the model.
        checkpoint: when True, save the full model every 2 epochs if greedy BLEU improved.

    Side effects: prints per-epoch metrics, appends them to ``<model_name>/eval.txt``,
    steps an LR scheduler on greedy BLEU, and optionally writes ``<epoch>.pt`` checkpoints.
    """
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    # Scheduler maximizes greedy BLEU (mode='max'); long patience + cooldown so the LR
    # only drops after a sustained plateau.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=30, factor=0.25, verbose=True, cooldown=6)
    pad = TRG_TEXT.vocab.stoi['<pad>']
    # reduction='mean' is the modern equivalent of the deprecated size_average=True.
    loss = nn.NLLLoss(reduction='mean', ignore_index=pad)
    cur_best = 0
    for epoch in range(num_epochs):
        model.train()
        # Annealed KL weight for this epoch.
        alpha = anneal(epoch, gpu=gpu) * kl_coef
        train_nre = 0
        train_kl_word = 0
        train_kl_sent = 0
        train_mu_dist = 0
        train_p_scale = 0
        train_q_scale = 0
        for batch in tqdm(train_iter):
            # note: word dropout masking works this way because the '<unk>' tokens happen to be 0 in both languages
            src = batch.src
            trg = batch.trg
            # Non-pad target tokens minus one start token per sentence (batch size = trg.size(1)).
            trg_word_cnt = (trg != pad).float().sum() - trg.size(1)
            re, kl, hidden, mu_prior, log_var_prior, mu_posterior, log_var_posterior = model(src, trg, word_dpt)
            kl_word = kl.sum() / trg_word_cnt  # KL by word
            kl_sent = kl.sum() / len(kl)  # KL by sent
            # Shift: predictions re[:-1] are scored against targets trg[1:].
            nre = loss(re[:-1, :, :].view(-1, re.size(2)), trg[1:, :].view(-1))
            # Free-bits floor (clamp to min_kl) before applying the annealed weight.
            neg_elbo = nre + alpha * kl_word.clamp(min_kl)
            train_nre += nre.item()
            train_kl_word += kl_word.item()
            train_kl_sent += kl_sent.item()
            train_mu_dist += (mu_prior - mu_posterior).abs().mean().item()
            train_p_scale += log_var_prior.mul(0.5).exp().mean().item()
            train_q_scale += log_var_posterior.mul(0.5).exp().mean().item()
            optimizer.zero_grad()
            neg_elbo.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        # Average the accumulated per-batch statistics over the epoch.
        train_nre /= len(train_iter)
        train_kl_word /= len(train_iter)
        train_kl_sent /= len(train_iter)
        train_mu_dist /= len(train_iter)
        train_p_scale /= len(train_iter)
        train_q_scale /= len(train_iter)
        train_elbo = train_nre + train_kl_word
        train_perp = np.exp(train_elbo)
        val_perp, val_elbo, val_nre, val_kl_word, val_kl_sent, val_mu_dist, val_p_scale, val_q_scale = utils.eval_vae(model, val_iter, pad, gpu)
        # greedy search
        model.if_zero = False
        bleu_greedy = utils.test_multibleu(model, val_iter, TRG_TEXT, k=1, gpu=gpu)
        scheduler.step(bleu_greedy)
        # greedy search - zeroed out latent vector (measures how much the latent is used)
        model.if_zero = True
        bleu_zero = utils.test_multibleu(model, val_iter, TRG_TEXT, k=1, gpu=gpu)
        results = 'Epoch: {}\n' \
                  '\tVALID PB: {:.4f} NELBO: {:.4f} RE: {:.4f} KL/W: {:.4f} KL/S: {:.4f}\n' \
                  '\tTRAIN PB: {:.4f} NELBO: {:.4f} RE: {:.4f} KL/W: {:.4f} KL/S: {:.4f}\n'\
                  '\tBLEU Greedy: {:.4f}\n\tBLEU Zero Greedy: {:.4f}\n'\
                  '\tVALID MU_DIST: {:.4f} P_SCALE: {:.4f} Q_SCALE: {:.4f}\n'\
                  '\tTRAIN MU_DIST: {:.4f} P_SCALE: {:.4f} Q_SCALE: {:.4f}'\
                  .format(epoch+1, val_perp, val_elbo, val_nre, val_kl_word, val_kl_sent,
                          np.exp(train_elbo), train_elbo, train_nre, train_kl_word, train_kl_sent, bleu_greedy, bleu_zero, val_mu_dist, val_p_scale, val_q_scale, train_mu_dist, train_p_scale, train_q_scale)
        print(results)
        # Log every epoch (the original `if not (epoch + 1) % 1` guard was always true).
        model_path = os.path.join(os.getcwd(), model_name)
        os.makedirs(model_path, exist_ok=True)  # race-free, idempotent
        eval_file = os.path.join(model_path, "eval.txt")
        if epoch == 0:
            # First epoch: start a fresh log with the model architecture + parameter count.
            with open(eval_file, "w") as f:
                f.write("{}".format(model))
                f.write("Number of parameters: " + str(utils.count_parameters(model)) + "\n")
        with open(eval_file, "a") as f:
            f.write("{}\n".format(results))
        # Checkpoint every 2 epochs when greedy BLEU improved.
        if (not (epoch + 1) % 2) and checkpoint and bleu_greedy > cur_best:
            model_file = os.path.join(model_path, str(epoch + 1) + ".pt")
            # NOTE(review): this pickles the whole model object; saving
            # model.state_dict() would be more portable — confirm loaders before changing.
            torch.save(model, model_file)
            cur_best = bleu_greedy