-
Notifications
You must be signed in to change notification settings - Fork 24
/
resnet_baseline.py
71 lines (54 loc) · 2.3 KB
/
resnet_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from torch import nn
from .utils import ConvBlock, Conv1dSamePadding
class ResNetBaseline(nn.Module):
    """A PyTorch implementation of the ResNet Baseline.

    From https://arxiv.org/abs/1909.04939

    Parameters
    ----------
    in_channels:
        The number of input channels (features per time step) of the input
    mid_channels:
        The 3 residual blocks will have as output channels:
        [mid_channels, mid_channels * 2, mid_channels * 2]
    num_pred_classes:
        The number of output classes
    """

    def __init__(self, in_channels: int, mid_channels: int = 64,
                 num_pred_classes: int = 1) -> None:
        super().__init__()

        # For easier saving and loading. FIX: mid_channels was previously
        # missing here, so a model built with a non-default mid_channels
        # could not be reconstructed from these saved args.
        self.input_args = {
            'in_channels': in_channels,
            'mid_channels': mid_channels,
            'num_pred_classes': num_pred_classes
        }

        # Three residual blocks; channel widths follow the docstring above.
        self.layers = nn.Sequential(*[
            ResNetBlock(in_channels=in_channels, out_channels=mid_channels),
            ResNetBlock(in_channels=mid_channels, out_channels=mid_channels * 2),
            ResNetBlock(in_channels=mid_channels * 2, out_channels=mid_channels * 2),
        ])
        # Linear head applied after global average pooling over time.
        self.final = nn.Linear(mid_channels * 2, num_pred_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        # x is assumed to be (batch, channels, seq_len) — mean over the last
        # (time) axis is global average pooling before the classifier.
        x = self.layers(x)
        return self.final(x.mean(dim=-1))
class ResNetBlock(nn.Module):
    """One residual block: three ConvBlocks (kernel sizes 8, 5, 3), plus a
    1x1-conv + batch-norm shortcut that is built and added only when the
    channel count changes.

    NOTE(review): when in_channels == out_channels no identity shortcut is
    added at all, i.e. the block is purely feed-forward — confirm this
    matches the intended baseline behavior.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()

        # Build the main path: first conv maps in -> out, the rest out -> out.
        conv_stack = []
        width = in_channels
        for kernel in (8, 5, 3):
            conv_stack.append(
                ConvBlock(in_channels=width, out_channels=out_channels,
                          kernel_size=kernel, stride=1))
            width = out_channels
        self.layers = nn.Sequential(*conv_stack)

        # Shortcut path exists only when the channel count must be adapted.
        self.match_channels = in_channels != out_channels
        if self.match_channels:
            self.residual = nn.Sequential(
                Conv1dSamePadding(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=1, stride=1),
                nn.BatchNorm1d(num_features=out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        out = self.layers(x)
        if self.match_channels:
            out = out + self.residual(x)
        return out