image_embedder_model.py
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import pytorch_lightning
import torch
import torchvision
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.nn import functional as F

from flash.core import Task
from flash.core.data import TaskDataPipeline
from flash.core.data.utils import _contains_any_tensor
from flash.vision.classification.data import _default_valid_transforms, _pil_loader, ImageClassificationData
from flash.vision.embedding.model_map import _load_model, _models

_resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2])  # noqa: E731
_resnet_feats = lambda model: model.fc.in_features  # noqa: E731

_backbones = {
    "resnet18": (torchvision.models.resnet18, _resnet_backbone, _resnet_feats),
    "resnet34": (torchvision.models.resnet34, _resnet_backbone, _resnet_feats),
    "resnet50": (torchvision.models.resnet50, _resnet_backbone, _resnet_feats),
    "resnet101": (torchvision.models.resnet101, _resnet_backbone, _resnet_feats),
    "resnet152": (torchvision.models.resnet152, _resnet_backbone, _resnet_feats),
}
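
# Minimal usage sketch (illustrative, not part of the original file) of how a
# _backbones entry is consumed: each value is a (constructor, feature-extractor
# splitter, feature-count getter) triple.
#
#   backbone_fn, split, num_feats = _backbones["resnet18"]
#   model = backbone_fn(pretrained=False)  # full torchvision ResNet
#   features = split(model)                # nn.Sequential without avgpool/fc
#   dim = num_feats(model)                 # 512 for resnet18/34, 2048 for resnet50+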


class ImageEmbedderDataPipeline(TaskDataPipeline):
    """Pipeline that loads images from paths and applies the validation transforms."""

    def __init__(
        self,
        valid_transform: Optional[Callable] = _default_valid_transforms,
        loader: Callable = _pil_loader,
    ):
        self._valid_transform = valid_transform
        self._loader = loader

    def before_collate(self, samples: Any) -> Any:
        # Tensors pass through untouched; a single path is promoted to a list of paths.
        if _contains_any_tensor(samples):
            return samples
        if isinstance(samples, str):
            samples = [samples]
        if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
            outputs = []
            for sample in samples:
                output = self._loader(sample)
                outputs.append(self._valid_transform(output))
            return outputs
        raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.")
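
# Usage sketch (illustrative, not part of the original file): before_collate
# accepts tensors unchanged, a single path, or a list of paths. The file names
# below are hypothetical.
#
#   pipeline = ImageEmbedderDataPipeline()
#   batch = pipeline.before_collate(["cat.jpg", "dog.jpg"])
#   # -> list of transformed image tensors, ready for the default collate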


class ImageEmbedder(Task):
    """Task that embeds images into a vector of features.

    Args:
        embedding_dim: Dimension of the embedded vector. `None` uses the default from the backbone.
        backbone: A model to use to extract image features.
        pretrained: Use a pretrained backbone.
        loss_fn: Loss function for training and finetuning, defaults to cross entropy.
        optimizer: Optimizer to use for training and finetuning, defaults to `torch.optim.SGD`.
        metrics: Metrics to compute for training and evaluation.
        learning_rate: Learning rate to use for training, defaults to `1e-3`.
        pooling_fn: Function used to pool the image features into embeddings, defaults to `torch.max`.

    Example::

        from flash.vision import ImageEmbedder

        embedder = ImageEmbedder(backbone='swav-imagenet')

        image = torch.rand(32, 3, 32, 32)

        embeddings = embedder(image)

    """

    def __init__(
        self,
        embedding_dim: Optional[int] = None,
        backbone: str = "swav-imagenet",
        pretrained: bool = True,
        loss_fn: Callable = F.cross_entropy,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
        metrics: Union[Callable, Mapping, Sequence, None] = Accuracy(),
        learning_rate: float = 1e-3,
        pooling_fn: Callable = torch.max,
    ):
        super().__init__(
            model=None,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
        )

        self.save_hyperparameters()
        self.backbone_name = backbone
        self.embedding_dim = embedding_dim
        assert pooling_fn in [torch.mean, torch.max]
        self.pooling_fn = pooling_fn

        if backbone in _models:
            # Pretrained self-supervised models from bolts (e.g. swav-imagenet).
            config = _load_model(backbone)
            self.backbone = config['model']
            num_features = config['num_features']
        elif backbone not in _backbones:
            raise NotImplementedError(f"Backbone {backbone} is not yet supported")
        else:
            # torchvision ResNets: strip the avgpool/fc head and keep the feature extractor.
            backbone_fn, split, num_feats = _backbones[backbone]
            backbone = backbone_fn(pretrained=pretrained)
            self.backbone = split(backbone)
            num_features = num_feats(backbone)

        if embedding_dim is None:
            self.head = nn.Identity()
        else:
            self.head = nn.Sequential(
                nn.Flatten(),
                nn.Linear(num_features, embedding_dim),
            )
            rank_zero_warn('embedding_dim is not None. Remember to finetune first!')

    def apply_pool(self, x):
        # Collapse a (B, C, H, W) feature map to (B, C) by pooling over the two
        # spatial dimensions, one at a time.
        if self.pooling_fn == torch.max:
            # torch.max also returns the argmax, so keep only the values.
            x = self.pooling_fn(x, dim=-1)[0]
            x = self.pooling_fn(x, dim=-1)[0]
        else:
            x = self.pooling_fn(x, dim=-1)
            x = self.pooling_fn(x, dim=-1)
        return x

    def forward(self, x) -> Any:
        x = self.backbone(x)

        # bolts ssl models can return a list or tuple of feature maps; keep the last one
        if isinstance(x, (list, tuple)):
            x = x[-1]

        if x.dim() == 4 and self.embedding_dim is not None:
            x = self.apply_pool(x)

        x = self.head(x)
        return x

    @staticmethod
    def default_pipeline() -> ImageEmbedderDataPipeline:
        return ImageEmbedderDataPipeline()
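

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative only, not part of the original
# module). It uses the 'resnet18' backbone registered above; pretrained=False
# avoids a weight download. With embedding_dim=128, the ResNet feature map
# (B, 512, H, W) is max-pooled to (B, 512) and projected to (B, 128).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embedder = ImageEmbedder(backbone="resnet18", embedding_dim=128, pretrained=False)
    images = torch.rand(4, 3, 224, 224)  # random NCHW batch
    embeddings = embedder(images)
    print(embeddings.shape)  # torch.Size([4, 128])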