-
-
Notifications
You must be signed in to change notification settings - Fork 392
/
Copy pathcommon.py
197 lines (167 loc) · 6.64 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# flake8: noqa
from typing import Dict, Optional
from datasets import DATASETS
import torch
from torch.utils.data import DataLoader
from catalyst import utils
from catalyst.contrib import nn, ResidualBlock
from catalyst.data import SelfSupervisedDatasetWrapper
def add_arguments(parser) -> None:
    """Add the common experiment arguments to an argparse-like parser.

    Registered options:
        --dataset: CIFAR-10, CIFAR-100 or STL10
        --logdir: logs directory (tensorboard, weights, etc)
        --epochs: number of sweeps over the dataset to train
        --num-workers: number of workers to process a dataloader
        --batch-size: number of images in each mini-batch
        --feature-dim: feature dim for latent vector
        --temperature: temperature used in softmax
        --learning-rate: learning rate for optimizer
        --verbose: boolean verbosity flag

    Args:
        parser: argparser like object
    """
    parser.add_argument(
        "--dataset",
        default="CIFAR-10",
        type=str,
        choices=DATASETS.keys(),
        help="Dataset: CIFAR-10, CIFAR-100 or STL10",
    )
    parser.add_argument(
        "--logdir",
        default="./logdir",
        type=str,
        help="Logs directory (tensorboard, weights, etc)",
    )
    parser.add_argument(
        "--epochs", default=1000, type=int, help="Number of sweeps over the dataset to train"
    )
    # NOTE: was ``type=float``; DataLoader requires an integer ``num_workers``,
    # and ``get_loaders`` annotates it as Optional[int].
    parser.add_argument(
        "--num-workers", default=1, type=int, help="Number of workers to process a dataloader"
    )
    parser.add_argument(
        "--batch-size", default=512, type=int, help="Number of images in each mini-batch"
    )
    parser.add_argument(
        "--feature-dim", default=128, type=int, help="Feature dim for latent vector"
    )
    parser.add_argument(
        "--temperature", default=0.5, type=float, help="Temperature used in softmax"
    )
    parser.add_argument(
        "--learning-rate", default=0.001, type=float, help="Learning rate for optimizer"
    )
    utils.boolean_flag(parser=parser, name="verbose", default=False)
class ContrastiveModel(torch.nn.Module):
    """Contrastive model with projective head.

    Args:
        model: projective head for the train time
        encoder: model for the future uses
    """

    def __init__(self, model, encoder):
        """Store the projection head and the backbone encoder."""
        super().__init__()
        self.model = model
        self.encoder = encoder

    def forward(self, x):
        """Run the encoder, then project its embedding.

        Args:
            x: input for the encoder

        Returns:
            (embeddings, projections)
        """
        embeddings = self.encoder(x)
        projections = self.model(embeddings)
        return embeddings, projections
def get_loaders(
    dataset: str, batch_size: int, num_workers: Optional[int]
) -> Dict[str, DataLoader]:
    """Init loaders based on parsed parameters.

    Args:
        dataset: dataset for the experiment (a key of ``DATASETS``)
        batch_size: batch size for loaders
        num_workers: number of workers to process loaders

    Returns:
        {"train":..., "valid":...}
    """
    transforms = DATASETS[dataset]["train_transform"]
    transform_original = DATASETS[dataset]["valid_transform"]
    try:
        # CIFAR-style datasets select the subset with ``train=True/False``.
        train_data = DATASETS[dataset]["dataset"](root="data", train=True, download=True)
        valid_data = DATASETS[dataset]["dataset"](root="data", train=False, download=True)
    except TypeError:
        # STL10-style datasets take ``split=`` instead of ``train=``; calling
        # them with the unexpected ``train`` kwarg raises TypeError. The
        # previous bare ``except:`` also hid download/IO errors.
        train_data = DATASETS[dataset]["dataset"](root="data", split="train", download=True)
        valid_data = DATASETS[dataset]["dataset"](root="data", split="test", download=True)
    train_data = SelfSupervisedDatasetWrapper(
        train_data,
        transforms=transforms,
        transform_original=transform_original,
    )
    valid_data = SelfSupervisedDatasetWrapper(
        valid_data,
        transforms=transforms,
        transform_original=transform_original,
    )
    train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=num_workers)
    valid_loader = DataLoader(valid_data, batch_size=batch_size, num_workers=num_workers)
    return {"train": train_loader, "valid": valid_loader}
def conv_block(in_channels, out_channels, pool=False):
    """Build a 3x3 conv -> batchnorm -> ReLU block, optionally max-pooled.

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        pool: if True, append a 2x2 max-pooling layer

    Returns:
        ``nn.Sequential`` with the assembled layers
    """
    modules = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    if pool:
        modules += [nn.MaxPool2d(2)]
    return nn.Sequential(*modules)
def resnet_mnist(in_size: int, in_channels: int, out_features: int, size: int = 16):
    """Build a small ResNet-style encoder for low-resolution (MNIST-like) images.

    Args:
        in_size: size of an image (in_size x in_size)
        in_channels: number of channels in an image
        out_features: dimensionality of the output embedding
        size: base channel width; later stages use multiples of it

    Returns:
        ``nn.Sequential`` encoder producing ``out_features``-dim vectors
    """
    width, width_x2, width_x4 = size, 2 * size, 4 * size
    # Flattened feature count after the two pooled stages and the final MaxPool2d(4).
    flat_features = (((in_size // 16) * 16) ** 2 * 4) // size
    classifier = nn.Sequential(
        nn.MaxPool2d(4),
        nn.Flatten(),
        nn.Dropout(0.2),
        nn.Linear(flat_features, out_features),
    )
    return nn.Sequential(
        conv_block(in_channels, width),
        conv_block(width, width_x2, pool=True),
        ResidualBlock(
            nn.Sequential(conv_block(width_x2, width_x2), conv_block(width_x2, width_x2))
        ),
        conv_block(width_x2, width_x4, pool=True),
        ResidualBlock(
            nn.Sequential(conv_block(width_x4, width_x4), conv_block(width_x4, width_x4))
        ),
        classifier,
    )
def resnet9(in_size: int, in_channels: int, out_features: int, size: int = 16):
    """Build a ResNet9-style encoder for images of resolution at least 32x32.

    Args:
        in_size: size of an image (in_size x in_size)
        in_channels: number of channels in an image
        out_features: dimensionality of the output embedding
        size: base channel width; later stages use multiples of it

    Returns:
        ``nn.Sequential`` encoder producing ``out_features``-dim vectors

    Raises:
        ValueError: if ``in_size`` is below 32 (the graph is invalid then)
    """
    if in_size < 32:
        # Was an ``assert``, which is stripped under ``python -O``; raise
        # explicitly so the validation always runs. Also fixes "lower then" typo.
        raise ValueError(
            "The graph is not valid for images with resolution lower than 32x32."
        )
    sz, sz2, sz4, sz8 = size, size * 2, size * 4, size * 8
    # Flattened feature count after the three pooled stages and the final MaxPool2d(4).
    out_size = (((in_size // 32) * 32) ** 2 * 2) // size
    return nn.Sequential(
        conv_block(in_channels, sz),
        conv_block(sz, sz2, pool=True),
        ResidualBlock(nn.Sequential(conv_block(sz2, sz2), conv_block(sz2, sz2))),
        conv_block(sz2, sz4, pool=True),
        conv_block(sz4, sz8, pool=True),
        ResidualBlock(nn.Sequential(conv_block(sz8, sz8), conv_block(sz8, sz8))),
        nn.Sequential(
            nn.MaxPool2d(4), nn.Flatten(), nn.Dropout(0.2), nn.Linear(out_size, out_features)
        ),
    )
def get_contrastive_model(
    in_size: int, in_channels: int, feature_dim: int, encoder_dim: int = 512, hidden_dim: int = 512
) -> ContrastiveModel:
    """Init contrastive model based on parsed parameters.

    Args:
        in_size: size of an image (in_size x in_size)
        in_channels: number of channels in an image
        feature_dim: dimensionality of contrastive projection
        encoder_dim: dimensionality of encoder output
        hidden_dim: dimensionality of encoder-contrastive projection

    Returns:
        ContrastiveModel instance
    """
    try:
        encoder = resnet9(in_size=in_size, in_channels=in_channels, out_features=encoder_dim)
    except (AssertionError, ValueError):
        # resnet9 rejects resolutions below 32x32; fall back to the small
        # encoder for low-resolution (MNIST-like) inputs. The previous bare
        # ``except:`` hid every other construction error.
        encoder = resnet_mnist(in_size=in_size, in_channels=in_channels, out_features=encoder_dim)
    projection_head = nn.Sequential(
        nn.Linear(encoder_dim, hidden_dim, bias=False),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, feature_dim, bias=True),
    )
    model = ContrastiveModel(projection_head, encoder)
    return model