-
Notifications
You must be signed in to change notification settings - Fork 1
/
vanilla_train.py
79 lines (58 loc) · 2.72 KB
/
vanilla_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from tqdm import tqdm
import os
import numpy as np
def _train_one_epoch(model, train_iter, loss, optimizer, gpu):
    """Run one optimisation pass over ``train_iter`` and return the mean batch NLL.

    The returned value is the sum of per-batch losses divided by the number of
    batches (``len(train_iter)``).
    """
    model.train()
    total_nll = 0.0
    for batch in tqdm(train_iter):
        src, trg = (batch.src.cuda(), batch.trg.cuda()) if gpu else (batch.src, batch.trg)
        # The model returns (log-probs, hidden); only the log-probs are used here.
        ll, _ = model(src, trg)
        # Pair prediction at position t with target token t+1 along dim 0, then
        # flatten for NLLLoss. Assumes ll is (seq_len, batch, vocab) log-probs and
        # trg is (seq_len, batch) indices -- inferred from the slicing; TODO confirm.
        # reshape (not view) so non-contiguous model outputs are also accepted.
        nll = loss(ll[:-1, :, :].reshape(-1, ll.size(2)), trg[1:, :].reshape(-1))
        total_nll += nll.item()
        optimizer.zero_grad()
        nll.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
    return total_nll / len(train_iter)


def _append_eval_log(model_name, model, results, first_epoch):
    """Record ``results`` in ``<cwd>/<model_name>/eval.txt`` and return that directory.

    On the first epoch the file is truncated and begins with the model's repr
    and its parameter count; later epochs only append.
    """
    model_path = os.path.join(os.getcwd(), model_name)
    os.makedirs(model_path, exist_ok=True)
    eval_file = os.path.join(model_path, "eval.txt")
    with open(eval_file, "w" if first_epoch else "a") as f:
        if first_epoch:
            # Newline after the repr so the parameter count gets its own line.
            f.write("{}\n".format(model))
            f.write("Number of parameters: " + str(utils.count_parameters(model)) + "\n")
        f.write("{}\n".format(results))
    return model_path


def train(model, model_name, train_iter, val_iter, SRC_TEXT, TRG_TEXT, num_epochs=20, gpu=False, lr=0.001, weight_decay=0, checkpoint=False):
    """Train a seq2seq ``model`` with Adam and log validation metrics each epoch.

    Each epoch the function:
    1. trains on ``train_iter``;
    2. reports validation perplexity/NLL and greedy BLEU on ``val_iter``;
    3. steps an LR-on-plateau scheduler on greedy BLEU.
    On even epochs it additionally computes ``utils.test_multibleu`` with its
    default settings.

    Args:
        model: callable as ``model(src, trg)`` returning ``(log_probs, hidden)``.
        model_name: directory (relative to the current working directory)
            where ``eval.txt`` and checkpoints are written.
        train_iter: iterable of batches exposing ``.src`` and ``.trg``; must
            support ``len()``.
        val_iter: validation batches, passed to the ``utils`` evaluators.
        SRC_TEXT: unused here; kept for interface compatibility.
        TRG_TEXT: target field; its ``vocab.stoi['<pad>']`` index is ignored by
            the loss and passed to the evaluator.
        num_epochs: number of training epochs.
        gpu: if True, batches are moved to CUDA with ``.cuda()``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).
        checkpoint: if True, on even epochs the whole model is saved with
            ``torch.save`` to ``<model_name>/<epoch>.pt`` whenever greedy BLEU
            beats the best value seen so far (starting from 0).

    Returns:
        None. Results are printed and appended to ``<model_name>/eval.txt``.
    """
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr,
        weight_decay=weight_decay,
    )
    # The scheduler is stepped with greedy BLEU, so 'max' mode (higher is better).
    # NOTE(review): `verbose` is deprecated in newer PyTorch releases -- check the installed version.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'max', patience=30, factor=0.25, verbose=True, cooldown=6)
    pad = TRG_TEXT.vocab.stoi['<pad>']
    # reduction='mean' is the non-deprecated spelling of size_average=True.
    loss = nn.NLLLoss(reduction='mean', ignore_index=pad)
    cur_best = 0
    for epoch in range(num_epochs):
        train_nll = _train_one_epoch(model, train_iter, loss, optimizer, gpu)
        train_perp = np.exp(train_nll)
        # Presumably returns (perplexity, nll) -- order taken from the unpacking names.
        val_perp, val_nll = utils.eval_seq2seq(model, val_iter, pad, gpu)
        # greedy search (k=1)
        bleu_greedy = utils.test_multibleu(model, val_iter, TRG_TEXT, k=1, gpu=gpu)
        scheduler.step(bleu_greedy)
        results = 'Epoch: {}\n' \
                  '\tVALID PPL: {:.4f} NLL: {:.4f}\n' \
                  '\tTRAIN PPL: {:.4f} NLL: {:.4f}\n' \
                  '\tBLEU Greedy: {:.4f}' \
                  .format(epoch + 1, val_perp, val_nll, train_perp, train_nll, bleu_greedy)
        even_epoch = (epoch + 1) % 2 == 0
        if even_epoch:
            # Extra BLEU with test_multibleu's default settings (only every 2nd epoch).
            bleu = utils.test_multibleu(model, val_iter, TRG_TEXT, gpu=gpu)
            results += '\n\tBLEU: {:.4f}'.format(bleu)
        print(results)
        model_path = _append_eval_log(model_name, model, results, first_epoch=(epoch == 0))
        if even_epoch and checkpoint and bleu_greedy > cur_best:
            model_file = os.path.join(model_path, "{}.pt".format(epoch + 1))
            torch.save(model, model_file)
            cur_best = bleu_greedy