# train.py (forked from SpaceAbleOrg/symplectic-hnn)
import os
import sys
import copy
import numpy as np
import torch
from timeit import default_timer as timer
from datetime import timedelta
# from torch.utils.tensorboard import SummaryWriter
from model.loss import OneStepLoss
from model.hnn import HNN
from model.args import get_args
from model.data import choose_data
from utils import save_path, to_pickle, from_pickle

# This function is generic, but it needs to run in a top-level file so that the path
# variables and the save directory are set up properly.
def setup(args, save_dir_prefix='/experiment-'):
    # Use the directory of this file as the working (save) directory
    this_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(this_dir)
    sys.path.append(parent_dir)

    # Set the save directory if none is given, i.e. '<this_dir>/experiment-<name>'
    if not args.save_dir:
        args.save_dir = this_dir + save_dir_prefix + args.name

    # Store the data class directly in args for future access, and its dimension
    # for convenience (e.g. of the loss functions)
    args.data_class = choose_data(args.name)
    args.dim = args.data_class.dimension()

    # Set the random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    return args

def load_data(args):
    data_path = save_path(args, ext='shnndata', incl_loss=False)

    if args.new_data or not os.path.exists(data_path):
        if args.verbose:
            print("Generating a new data set...")
        data_loader = args.data_class(args.h, args.noise)
        data = data_loader.get_dataset(seed=args.seed, samples=args.data_samples,
                                       test_split=args.test_split, print_args=args)
        os.makedirs(args.save_dir, exist_ok=True)
        to_pickle(data, data_path)
    else:
        if args.verbose:
            print("Loading the existing data set...")
        data = from_pickle(data_path)

    return data
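
# Note: judging from its use in train() below, the dataset dictionary returned here
# contains at least the keys 'coords', 'test_coords', 't' and 'test_t'.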

def train(model, data, args):
    # Create a standard optimizer
    # optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4)
    optim = torch.optim.AdamW(model.parameters(), args.learn_rate, betas=(0.9, 0.999), eps=1e-08,
                              weight_decay=0.01, amsgrad=True)

    # Load the symplectic (or not) loss function
    loss_fct = OneStepLoss(args)  # Choosing the actual loss_type is hidden in this constructor
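    # Based on its use below, a OneStepLoss instance is callable as
    # loss_fct(model, coords, t); with return_dist=True it returns the per-sample
    # distances rather than the aggregated loss.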

    # Prepare objects from the dataset dictionary
    x = torch.tensor(data['coords'], requires_grad=True, dtype=torch.float32)  # shape (set_size, dim)
    test_x = torch.tensor(data['test_coords'], requires_grad=True, dtype=torch.float32)  # shape (set_size, dim)
    t = data['t']
    test_t = data['test_t']

    # DO VANILLA TRAINING LOOP
    if args.verbose:
        print("Begin training loop...")

    best_model, best_test_loss = model, np.inf
    stats = {'train_loss': [], 'test_loss': []}
    start_time = timer()
    # writer = SummaryWriter()

    for step in range(args.epochs + 1):
        # TODO: use stochastic gradient descent with minibatches of size args.batch_size;
        # a hedged sketch follows below
        # for ixs in torch.split(torch.arange(x.shape[0]), args.batch_size):
        #     ...
        #     loss = loss_fct(model, x[ixs], t[ixs])
        #     ...
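        # A minimal sketch of that minibatched variant (kept commented out like the
        # TODO above; it assumes loss_fct accepts index-sliced coords and times, as
        # the pseudocode above already suggests):
        # for ixs in torch.split(torch.randperm(x.shape[0]), args.batch_size):
        #     loss = loss_fct(model, x[ixs], t[ixs])
        #     loss.backward()
        #     optim.step()
        #     optim.zero_grad()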

        # Training step: compute the loss and optimize
        model.train()
        loss = loss_fct(model, x, t)
        loss.backward()
        optim.step()
        optim.zero_grad()

        # Evaluate on the test data
        model.eval()
        train_loss_val = loss.cpu().detach()
        test_loss_val = loss_fct(model, test_x, test_t).cpu().detach()

        # Keep the best model seen during the second half of training
        if step > args.epochs / 2 and test_loss_val < best_test_loss:
            best_model = copy.deepcopy(model)
            best_test_loss = test_loss_val

        # Logging with tensorboard
        # writer.add_scalar("Loss/Train", train_loss_val, step)
        # writer.add_scalar("Loss/Test", test_loss_val, step)

        # Logging manually
        stats['train_loss'].append(train_loss_val)
        stats['test_loss'].append(test_loss_val)
        if args.verbose and step % args.print_every == 0:
            print(f"step {step}, train_loss {train_loss_val:.4e}, test_loss {test_loss_val:.4e}",
                  f" ({timedelta(seconds=int(timer()-start_time))} (h/m/s) elapsed)")

    # Final evaluation using the best_model
    train_dist = loss_fct(best_model, x, t, return_dist=True).cpu().detach().numpy()
    test_dist = loss_fct(best_model, test_x, test_t, return_dist=True).cpu().detach().numpy()
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
          .format(train_dist.mean().item(), train_dist.std().item() / np.sqrt(train_dist.shape[0]),
                  test_dist.mean().item(), test_dist.std().item() / np.sqrt(test_dist.shape[0])))

    return best_model, stats

def train_main(args):
    # SETUP ENV AND ARGUMENTS
    args = setup(args)

    # CREATE THE EMPTY MODEL
    model = HNN.create(args)

    # LOAD DATA SET
    data = load_data(args)

    # RUN THE MAIN FUNCTION TO TRAIN THE MODEL
    model, loss_log = train(model, data, args)

    # SAVE
    os.makedirs(args.save_dir, exist_ok=True)
    torch.save({'args': args, 'model': model.state_dict(), 'stats': loss_log}, save_path(args))

def train_if_missing(args, save_dir_prefix='/experiment-'):
    args = setup(args, save_dir_prefix=save_dir_prefix)
    if not os.path.exists(save_path(args)):
        train_main(args)

if __name__ == "__main__":
    """ This file can be run with one well-defined set of arguments. To run for several
    configurations at once, please consult the parallelize.py file. """
    train_main(get_args())
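
# Example invocation (a sketch only: the actual flag names are defined by get_args()
# in model/args.py and are assumed here to mirror the attribute names used above,
# e.g. args.name, args.epochs and args.verbose; the data set name is hypothetical):
#   python train.py --name pendulum --epochs 2000 --verbose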