import torch
from torch.utils.data import DataLoader
import numpy as np
from setting import Setting
from trainer import FlashbackTrainer
from dataloader import PoiDataloader
from dataset import Split
from network import create_h0_strategy
from evaluation import Evaluation

'''
Main train script, invoked from the command line.
'''
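# Example invocation (hypothetical flags; the actual argument names are defined
# by Setting.parse() and may differ):
#   python train.py --dataset ./data/checkins.txt
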
### parse settings ###
setting = Setting()
setting.parse()
print(setting)

### load dataset ###
poi_loader = PoiDataloader(setting.max_users, setting.min_checkins)
poi_loader.read(setting.dataset_file)
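# note: batching happens inside the dataset (each sample already packs
# setting.batch_size user streams), hence batch_size=1 on the DataLoaders below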
dataset = poi_loader.create_dataset(setting.sequence_length, setting.batch_size, Split.TRAIN)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
dataset_test = poi_loader.create_dataset(setting.sequence_length, setting.batch_size, Split.TEST)
dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False)
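# each batch slot carries one user's check-in stream, so there must be more
# users than slots (see the per-slot hidden-state resets in the loop below)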
assert setting.batch_size < poi_loader.user_count(), 'batch size must be lower than the amount of available users'

### create flashback trainer ###
trainer = FlashbackTrainer(setting.lambda_t, setting.lambda_s)
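# the h0 strategy supplies fresh hidden states on startup and whenever a batch
# slot is reassigned to a new user (a (hidden, cell) tuple in the LSTM case)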
h0_strategy = create_h0_strategy(setting.hidden_dim, setting.is_lstm)
trainer.prepare(poi_loader.locations(), poi_loader.user_count(), setting.hidden_dim, setting.rnn_factory, setting.device)
evaluation_test = Evaluation(dataset_test, dataloader_test, poi_loader.user_count(), h0_strategy, trainer, setting)
print('{} {}'.format(trainer, setting.rnn_factory))

### training loop ###
optimizer = torch.optim.Adam(trainer.parameters(), lr=setting.learning_rate, weight_decay=setting.weight_decay)
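# MultiStepLR scales the learning rate by gamma=0.2 each time training passes
# one of the milestone epochs (20, 40, 60, 80)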
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,40,60,80], gamma=0.2)

for e in range(setting.epochs):
    h = h0_strategy.on_init(setting.batch_size, setting.device)
    dataset.shuffle_users()  # shuffle users before each epoch!
    losses = []

    for i, (x, t, s, y, y_t, y_s, reset_h, active_users) in enumerate(dataloader):
        # reset hidden states for newly added users
        for j, reset in enumerate(reset_h):
            if reset:
                if setting.is_lstm:
                    # an LSTM state is a (hidden, cell) tuple; reset both parts
                    hc = h0_strategy.on_reset(active_users[0][j])
                    h[0][0, j] = hc[0]
                    h[1][0, j] = hc[1]
                else:
                    h[0, j] = h0_strategy.on_reset(active_users[0][j])

        # the DataLoader adds a leading batch dimension of size 1; squeeze it
        # away and move everything to the target device
        x = x.squeeze().to(setting.device)
        t = t.squeeze().to(setting.device)
        s = s.squeeze().to(setting.device)
        y = y.squeeze().to(setting.device)
        y_t = y_t.squeeze().to(setting.device)
        y_s = y_s.squeeze().to(setting.device)
        active_users = active_users.to(setting.device)

        optimizer.zero_grad()
        loss, h = trainer.loss(x, t, s, y, y_t, y_s, h, active_users)
        # retain_graph=True: the returned hidden state h is reused in the next
        # iteration and still references parts of the current graph
        loss.backward(retain_graph=True)
        losses.append(loss.item())
        optimizer.step()

    # advance the learning-rate schedule once per epoch
    scheduler.step()

    # statistics:
    if (e+1) % 1 == 0:  # log every epoch; raise the modulus to log less often
        epoch_loss = np.mean(losses)
        print(f'Epoch: {e+1}/{setting.epochs}')
        print(f'Used learning rate: {scheduler.get_last_lr()[0]}')
        print(f'Avg Loss: {epoch_loss}')
    if (e+1) % setting.validate_epoch == 0:
        print(f'~~~ Test Set Evaluation (Epoch: {e+1}) ~~~')
        evaluation_test.evaluate()