-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnet.py
68 lines (62 loc) · 2.25 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
class Net(torch.nn.Module):
    """Fully connected ReLU network bundled with a simple Adam training loop.

    The network is five ``Linear`` layers (input -> 4 x hidden -> output) with
    a ReLU after every layer except the last. A separate learnable scalar
    ``r`` is registered as a parameter; ``forward`` does not use it, so it
    only changes if a loss term passed to ``fit`` depends on it.
    """

    def __init__(self, n_input, n_output, n_hidden):
        """Build the network.

        Args:
            n_input: Number of input features.
            n_output: Number of output features.
            n_hidden: Width of each of the hidden layers.
        """
        super().__init__()
        # Learnable scalar initialised to 0 (labelled "variable cooling rate"
        # by the original author). Not used by forward(); presumably consumed
        # by a `loss2` term passed to fit() -- confirm against callers.
        self.r = torch.nn.Parameter(data=torch.tensor([0.]))
        # network
        self.net = torch.nn.Sequential(
            torch.nn.Linear(n_input, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_output),
        )

    @staticmethod
    def _as_2d_float(data):
        """Return a float32 copy of ``data`` reshaped to ``(len(data), -1)``."""
        if isinstance(data, torch.Tensor):
            # Same result as torch.tensor(data) (a detached copy) but without
            # the copy-construct UserWarning PyTorch emits for tensor inputs.
            tensor = data.detach().clone().to(torch.float32)
        else:
            tensor = torch.tensor(data, dtype=torch.float32)
        return tensor.reshape(len(data), -1)

    def _snapshot(self, test_set):
        """Predict on ``test_set`` as a flat NumPy array, keeping the mode.

        ``predict`` switches the module to eval mode; the previous mode is
        restored afterwards so a training loop is not silently left in eval
        mode.
        """
        was_training = self.training
        with torch.no_grad():
            flat = self.predict(test_set).detach().numpy().flatten()
        self.train(was_training)
        return flat

    def forward(self, x):
        """Run ``x`` through the sequential network and return its output."""
        return self.net(x)

    def fit(self, X, y, n_epochs, lr, loss_fn, loss2=None, loss2_weight=0.1, test_set=None):
        """Train all parameters with Adam on the full ``(X, y)`` batch.

        Args:
            X: Training inputs; converted to a float32 tensor with one row
                per element of ``X``.
            y: Training targets; converted the same way as ``X``.
            n_epochs: Number of optimisation steps.
            lr: Learning rate handed to ``torch.optim.Adam``.
            loss_fn: Callable invoked as ``loss_fn(y, output)``.
            loss2: Optional callable invoked as ``loss2(self)``; its result
                is scaled by ``loss2_weight`` and added to the loss.
            loss2_weight: Weight applied to the ``loss2`` term.
            test_set: Optional inputs on which predictions are recorded
                before training, every 100 epochs and after the last epoch.

        Returns:
            A tuple ``(losses, preds)``: ``losses`` holds one float per
            epoch; ``preds`` holds the flattened NumPy predictions on
            ``test_set`` (an empty list when ``test_set`` is ``None``).
        """
        # convert to tensors
        X = self._as_2d_float(X)
        y = self._as_2d_float(y)
        # optimizer
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        # train
        self.train()
        losses = []
        preds = []
        # Only record predictions when a test set was actually supplied;
        # predicting on None would raise before training even starts.
        if test_set is not None:
            preds.append(self._snapshot(test_set))
        for epoch in range(n_epochs):
            optimizer.zero_grad()
            output = self.forward(X)
            loss = loss_fn(y, output)
            if loss2 is not None:
                # Out-of-place add: never mutates the tensor loss_fn returned.
                loss = loss + loss2_weight * loss2(self)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            losses.append(loss_value)
            if epoch % 4000 == 0:
                print(f'Epoch {epoch}, loss {loss_value}')
            if test_set is not None and (epoch % 100 == 0 or epoch == n_epochs - 1):
                preds.append(self._snapshot(test_set))
        return losses, preds

    def predict(self, X):
        """Return the network output for ``X`` with the module in eval mode.

        ``X`` is converted to a float32 tensor with one row per element.
        The module is left in eval mode and gradient tracking is not
        disabled, so the returned tensor is still attached to the graph of
        the network parameters.
        """
        # convert to tensor
        X = self._as_2d_float(X)
        # predict
        self.eval()
        return self.forward(X)
def grad(outputs, inputs):
    """
    Differentiate ``outputs`` with respect to ``inputs`` via autograd.

    A tensor of ones shaped like ``outputs`` is used as ``grad_outputs``, and
    the graph of the derivative is kept (``create_graph=True``) so the result
    can itself be differentiated. The value returned is whatever
    ``torch.autograd.grad`` returns, unmodified.
    """
    seed = torch.ones_like(outputs)
    derivatives = torch.autograd.grad(
        outputs, inputs, grad_outputs=seed, create_graph=True
    )
    return derivatives