# inception.py
from typing import cast, List, Union

import torch
from torch import nn

from .utils import Conv1dSamePadding


class InceptionModel(nn.Module):
"""A PyTorch implementation of the InceptionTime model.
From https://arxiv.org/abs/1909.04939
Attributes
----------
num_blocks:
The number of inception blocks to use. One inception block consists
of 3 convolutional layers, (optionally) a bottleneck and (optionally) a residual
connector
in_channels:
The number of input channels (i.e. input.shape[-1])
out_channels:
The number of "hidden channels" to use. Can be a list (for each block) or an
int, in which case the same value will be applied to each block
bottleneck_channels:
The number of channels to use for the bottleneck. Can be list or int. If 0, no
bottleneck is applied
kernel_sizes:
The size of the kernels to use for each inception block. Within each block, each
of the 3 convolutional layers will have kernel size
`[kernel_size // (2 ** i) for i in range(3)]`
num_pred_classes:
The number of output classes
"""

    def __init__(self, num_blocks: int, in_channels: int, out_channels: Union[List[int], int],
                 bottleneck_channels: Union[List[int], int], kernel_sizes: Union[List[int], int],
                 use_residuals: Union[List[bool], bool, str] = 'default',
                 num_pred_classes: int = 1
                 ) -> None:
        super().__init__()

        # for easier saving and loading
        self.input_args = {
            'num_blocks': num_blocks,
            'in_channels': in_channels,
            'out_channels': out_channels,
            'bottleneck_channels': bottleneck_channels,
            'kernel_sizes': kernel_sizes,
            'use_residuals': use_residuals,
            'num_pred_classes': num_pred_classes
        }

        channels = [in_channels] + cast(List[int], self._expand_to_blocks(out_channels,
                                                                          num_blocks))
        bottleneck_channels = cast(List[int], self._expand_to_blocks(bottleneck_channels,
                                                                     num_blocks))
        kernel_sizes = cast(List[int], self._expand_to_blocks(kernel_sizes, num_blocks))
        if use_residuals == 'default':
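            # by default, use a residual connection after every third block
            # (i.e. blocks with index 2, 5, 8, ...)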
            use_residuals = [i % 3 == 2 for i in range(num_blocks)]
        use_residuals = cast(List[bool], self._expand_to_blocks(
            cast(Union[bool, List[bool]], use_residuals), num_blocks)
        )

        self.blocks = nn.Sequential(*[
            InceptionBlock(in_channels=channels[i], out_channels=channels[i + 1],
                           residual=use_residuals[i], bottleneck_channels=bottleneck_channels[i],
                           kernel_size=kernel_sizes[i]) for i in range(num_blocks)
        ])

        # a global average pooling (i.e. mean of the time dimension) is why
        # in_features=channels[-1]
        self.linear = nn.Linear(in_features=channels[-1], out_features=num_pred_classes)

    @staticmethod
    def _expand_to_blocks(value: Union[int, bool, List[int], List[bool]],
                          num_blocks: int) -> Union[List[int], List[bool]]:
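        # broadcast a scalar setting to one value per block; list inputs are
        # only validated against num_blocks, e.g.
        # _expand_to_blocks(32, num_blocks=3) -> [32, 32, 32]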
        if isinstance(value, list):
            assert len(value) == num_blocks, \
                f'Length of input lists must be the same as num blocks, ' \
                f'expected length {num_blocks}, got {len(value)}'
        else:
            value = [value] * num_blocks
        return value

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        x = self.blocks(x).mean(dim=-1)  # the mean is the global average pooling
        return self.linear(x)


class InceptionBlock(nn.Module):
    """An inception block consists of an (optional) bottleneck, followed
    by 3 conv1d layers, optionally with a residual connection
    """

    def __init__(self, in_channels: int, out_channels: int,
                 residual: bool, stride: int = 1, bottleneck_channels: int = 32,
                 kernel_size: int = 41) -> None:
        assert kernel_size > 3, "Kernel size must be strictly greater than 3"
        super().__init__()

        self.use_bottleneck = bottleneck_channels > 0
        if self.use_bottleneck:
            self.bottleneck = Conv1dSamePadding(in_channels, bottleneck_channels,
                                                kernel_size=1, bias=False)
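        # each block runs 3 conv layers with successively halved kernel sizes,
        # e.g. kernel_size=41 -> [41, 20, 10]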
        kernel_size_s = [kernel_size // (2 ** i) for i in range(3)]
        start_channels = bottleneck_channels if self.use_bottleneck else in_channels
        channels = [start_channels] + [out_channels] * 3
        self.conv_layers = nn.Sequential(*[
            Conv1dSamePadding(in_channels=channels[i], out_channels=channels[i + 1],
                              kernel_size=kernel_size_s[i], stride=stride, bias=False)
            for i in range(len(kernel_size_s))
        ])

        self.batchnorm = nn.BatchNorm1d(num_features=channels[-1])
        self.relu = nn.ReLU()

        self.use_residual = residual
        if residual:
            self.residual = nn.Sequential(*[
                Conv1dSamePadding(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
                nn.ReLU()
            ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        org_x = x
        if self.use_bottleneck:
            x = self.bottleneck(x)
        # batch norm and ReLU after the convolutional stack
        x = self.relu(self.batchnorm(self.conv_layers(x)))

        if self.use_residual:
            x = x + self.residual(org_x)
        return x
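

# Example usage (a minimal sketch; assumes this file lives in a package
# alongside utils.py, so the relative Conv1dSamePadding import resolves):
#
#     model = InceptionModel(num_blocks=3, in_channels=1, out_channels=32,
#                            bottleneck_channels=16, kernel_sizes=41,
#                            use_residuals='default', num_pred_classes=2)
#     x = torch.randn(8, 1, 128)  # (batch_size, in_channels, num_timesteps)
#     logits = model(x)           # -> torch.Size([8, 2])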