-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_unet.py
72 lines (51 loc) · 2.01 KB
/
train_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Train a simple Unet"""
import torch
from torchvision.transforms import Compose, Normalize, ToTensor
from nf_dataset import NFDataset
# Input preprocessing, applied in this order:
#   1. ToTensor()                 - convert the input to a torch tensor.
#   2. Normalize([0.], [300.])    - (x - 0) / 300, i.e. divide by 300.
#   3. Normalize([.5], [.5])      - (x - 0.5) / 0.5.
_preprocessing_steps = [
    ToTensor(),
    Normalize([0.], [300.]),
    Normalize([.5], [.5]),
]
data_transform = Compose(_preprocessing_steps)

# Tensor conversion intended for the targets.  The datasets in this script
# are constructed without a target_transform, so this is currently unused.
target_transform = ToTensor()
import glob
import os


def list_subject_folders(root):
    """Return the sorted names of the directories directly under *root*.

    Args:
        root: Path of the dataset directory (presumably one sub-folder per
            subject - confirm against NFDataset).

    Returns:
        Sorted list of sub-directory base names.  Empty when *root* does not
        exist or contains no sub-directories.
    """
    # The trailing "*" matters: globbing the bare root path matches only the
    # root directory itself, which yields a single bogus "subject" and makes
    # the train/test split below meaningless.
    pattern = os.path.join(glob.escape(root), "*")
    return sorted(
        os.path.basename(path)
        for path in glob.glob(pattern)
        if os.path.isdir(path)
    )


# Hold out the last 10 subjects (in sorted order) for testing; everything
# before them is used for training.
subject_folders = list_subject_folders("/home/michael/nf_dataset")
test_folders = subject_folders[-10:]
train_folders = subject_folders[:-10]
# Both datasets read from the same root directory and share the input
# transform; neither is given a target_transform.
root_dir = "/home/michael/nf_dataset"

# Training set: constructed with the held-out test subjects excluded.
nfdataset = NFDataset(
    root_dir,
    exclude_subjects=test_folders,
    data_transform=data_transform,
)

# Test set: constructed with the training subjects excluded.
nfdataset_test = NFDataset(
    root_dir,
    exclude_subjects=train_folders,
    data_transform=data_transform,
)
# Draw training samples at random (without replacement) from the indices
# listed in nfdataset.positive_counts.
# NOTE(review): the attribute name suggests counts, but it is used here as a
# collection of dataset indices - confirm against NFDataset.
sampled_indices = nfdataset.positive_counts
sampler = torch.utils.data.sampler.SubsetRandomSampler(sampled_indices)

# Batches of 32 samples, ordered by the sampler above.
dataloader = torch.utils.data.DataLoader(
    nfdataset,
    batch_size=32,
    sampler=sampler,
)
import sys

import numpy as np
from unet.unet_model import UNet

# Training schedule and bookkeeping.
n_epochs = 10
n_samples_per_epoch = 100000
all_epoch_avg_losses = []

# Build the network and move it to the GPU before creating the optimizer, so
# the optimizer is handed the parameters of the GPU-resident model.
# UNet(1, 1): presumably one input channel and one output channel - confirm
# against unet.unet_model.
unet = UNet(1, 1).cuda()

# Adam with the library's default hyper-parameters.
optimizer = torch.optim.Adam(unet.parameters())
# Training loop.  Each epoch performs at most n_samples_per_epoch optimisation
# steps (one per batch from the dataloader, so the cap counts batches, not
# individual samples), then checkpoints the model and records the epoch's
# mean loss in all_epoch_avg_losses.
for e in range(n_epochs):
    losses = []
    loss_total = 0.0  # running sum so the progress line stays O(1) per step
    # range() is the first zip argument so that, once the step cap is hit,
    # zip stops without fetching one more batch only to discard it.
    for ii, (x, y) in zip(range(n_samples_per_epoch), dataloader):
        x = x.cuda()
        # Keep index 0 of the target's last axis.
        y = y.cuda()[..., 0]
        # Map labels to -1/+1 (assumes y holds 0/1 values - TODO confirm).
        yy = y.float() * 2 - 1
        # Call the module itself rather than .forward() so registered hooks
        # run; index 0 along dim 1 of the output is used as the logit.
        p = unet(x)[:, 0]
        # Logistic loss on signed labels: mean of -log(sigmoid(p * yy)).
        loss = -torch.nn.functional.logsigmoid(p * yy).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float, detached from the graph.
        losses.append(loss.item())
        loss_total += losses[-1]
        # Progress: step / cap, last loss, mean of last 10, mean of epoch.
        # "\r" rewrites the same terminal line on every step.
        print(f"{ii:05d}/{n_samples_per_epoch} {losses[-1]:0.3f} {np.mean(losses[-10:]):0.3f} {loss_total / len(losses):0.3f}", end="\r", flush=True)
    # Per-epoch checkpoint.  This pickles the whole module object, so loading
    # it later requires the UNet class to be importable under the same path.
    torch.save(unet, f"unet0_cp{e:02d}.th")
    print("\n")
    all_epoch_avg_losses.append(np.mean(losses))