forked from mosheman5/DNP
-
Notifications
You must be signed in to change notification settings - Fork 1
/
unet.py
60 lines (52 loc) · 2.12 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
import torch.nn.functional as F
class Unet(nn.Module):
    """1-D U-Net mapping a single-channel signal to a single-channel output.

    Encoder: ``nlayers`` blocks of Conv1d -> BatchNorm1d -> LeakyReLU(0.1).
    Each block's output is kept as a skip connection, then the time axis is
    decimated by 2 (``x[:, :, ::2]``). A middle Conv1d/BatchNorm1d/LeakyReLU
    block follows.

    Decoder: mirrors the encoder. Each step linearly upsamples back to the
    length of the matching skip connection, concatenates the skip along the
    channel axis, and applies Conv1d -> BatchNorm1d -> LeakyReLU(0.1).

    Output: the raw input is concatenated as one extra channel, and a 1x1
    convolution followed by ``tanh`` produces the output, so values lie in
    [-1, 1].

    Args:
        nlayers: Number of encoder (and decoder) levels.
        nefilters: Channel increment per level; encoder level ``i`` outputs
            ``(i + 1) * nefilters`` channels.
    """

    def __init__(self, nlayers: int = 12, nefilters: int = 24):
        super().__init__()
        print('unet')
        self.num_layers = nlayers
        self.nefilters = nefilters
        filter_size = 15        # kernel size of encoder and middle convs
        merge_filter_size = 5   # kernel size of decoder convs (after each skip merge)
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.ebatch = nn.ModuleList()
        self.dbatch = nn.ModuleList()
        # Encoder level i reads the previous level's channels (1 for the raw
        # input) and emits (i + 1) * nefilters channels.
        echannelin = [1] + [(i + 1) * nefilters for i in range(nlayers - 1)]
        echannelout = [(i + 1) * nefilters for i in range(nlayers)]
        # The decoder walks from the deepest level back up, so its output
        # widths are the encoder's in reverse.
        dchannelout = echannelout[::-1]
        # Decoder input width = upsampled features + concatenated skip.
        # Step 0 merges the middle output (nlayers * nef) with the deepest skip
        # (nlayers * nef); later steps merge the previous decoder output
        # (i * nef) with the skip of width (i - 1) * nef.
        dchannelin = [dchannelout[0] * 2] + [
            i * nefilters + (i - 1) * nefilters for i in range(nlayers, 1, -1)
        ]
        for i in range(self.num_layers):
            # padding = k // 2 with an odd kernel and stride 1 preserves length.
            self.encoder.append(
                nn.Conv1d(echannelin[i], echannelout[i], filter_size,
                          padding=filter_size // 2))
            self.decoder.append(
                nn.Conv1d(dchannelin[i], dchannelout[i], merge_filter_size,
                          padding=merge_filter_size // 2))
            self.ebatch.append(nn.BatchNorm1d(echannelout[i]))
            self.dbatch.append(nn.BatchNorm1d(dchannelout[i]))
        self.middle = nn.Sequential(
            nn.Conv1d(echannelout[-1], echannelout[-1], filter_size,
                      padding=filter_size // 2),
            nn.BatchNorm1d(echannelout[-1]),
            nn.LeakyReLU(0.1),
        )
        # Last decoder output (nefilters channels) + the raw 1-channel input.
        self.out = nn.Sequential(
            nn.Conv1d(nefilters + 1, 1, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Tensor indexed as ``(batch, 1, length)``. The first encoder conv
                takes 1 channel, and skips are concatenated along dim 1.
                ``length`` need not be divisible by ``2 ** nlayers``.

        Returns:
            Tensor of shape ``(batch, 1, length)`` with values in [-1, 1].
        """
        skips = []
        inputs = x
        for conv, bn in zip(self.encoder, self.ebatch):
            x = F.leaky_relu(bn(conv(x)), 0.1)
            skips.append(x)
            # Decimate by 2; an odd length L becomes ceil(L / 2).
            x = x[:, :, ::2]
        x = self.middle(x)
        for conv, bn, skip in zip(self.decoder, self.dbatch, reversed(skips)):
            # Upsample to the skip's exact length instead of by a fixed factor
            # of 2. After an odd-length decimation, 2 * ceil(L / 2) == L + 1 and
            # the concatenation would fail. When lengths match, this equals
            # scale_factor=2 interpolation.
            x = F.interpolate(x, size=skip.shape[-1], mode='linear',
                              align_corners=False)
            x = torch.cat([x, skip], dim=1)
            x = F.leaky_relu(bn(conv(x)), 0.1)
        x = torch.cat([x, inputs], dim=1)
        return self.out(x)