This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
Commit d190c76 (1 parent: dfe8854)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Ethan Harris <[email protected]>
Showing 25 changed files with 1,098 additions and 141 deletions.
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry  # noqa: F401
from flash.image.embedding.backbones.vissl_backbones import register_vissl_backbones  # noqa: F401

IMAGE_EMBEDDER_BACKBONES = FlashRegistry("embedder_backbones")
register_vissl_backbones(IMAGE_EMBEDDER_BACKBONES)
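Once populated, the registry can be queried by name. The following is a minimal usage sketch, not part of this commit; it assumes VISSL is installed and uses FlashRegistry's get/available_keys accessors:

    # Hypothetical usage sketch: look up a backbone factory registered above.
    from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES

    print(IMAGE_EMBEDDER_BACKBONES.available_keys())
    backbone, num_features = IMAGE_EMBEDDER_BACKBONES.get("resnet")()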
@@ -0,0 +1,117 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE

if _VISSL_AVAILABLE:
    from vissl.config.attr_dict import AttrDict
    from vissl.models.model_helpers import RESNET_NORM_LAYER
    from vissl.models.trunks import MODEL_TRUNKS_REGISTRY

    from flash.image.embedding.vissl.adapter import VISSLAdapter
else:
    RESNET_NORM_LAYER = object


def vision_transformer(
    image_size: int = 224,
    patch_size: int = 16,
    hidden_dim: int = 384,
    num_layers: int = 12,
    num_heads: int = 6,
    mlp_dim: int = 1532,
    dropout_rate: float = 0,
    attention_dropout_rate: float = 0,
    drop_path_rate: float = 0,
    qkv_bias: bool = True,
    qk_scale: bool = False,
    classifier: str = "token",
    **kwargs,
) -> Tuple[nn.Module, int]:
    # Translate the keyword arguments into a VISSL trunk config and build the
    # ViT trunk from VISSL's trunk registry.
    cfg = VISSLAdapter.get_model_config_template()
    cfg.TRUNK = AttrDict(
        {
            "NAME": "vision_transformer",
            "VISION_TRANSFORMERS": AttrDict(
                {
                    "IMAGE_SIZE": image_size,
                    "PATCH_SIZE": patch_size,
                    "HIDDEN_DIM": hidden_dim,
                    "NUM_LAYERS": num_layers,
                    "NUM_HEADS": num_heads,
                    "MLP_DIM": mlp_dim,
                    "DROPOUT_RATE": dropout_rate,
                    "ATTENTION_DROPOUT_RATE": attention_dropout_rate,
                    "DROP_PATH_RATE": drop_path_rate,
                    "QKV_BIAS": qkv_bias,
                    "QK_SCALE": qk_scale,
                    "CLASSIFIER": classifier,
                }
            ),
        }
    )

    trunk = MODEL_TRUNKS_REGISTRY["vision_transformer"](cfg, model_name="vision_transformer")
    trunk.model_config = cfg

    return trunk, trunk.num_features


def resnet(
    depth: int = 50,
    width_multiplier: int = 1,
    norm: RESNET_NORM_LAYER = None,
    groupnorm_groups: int = 32,
    standardize_convolutions: bool = False,
    groups: int = 1,
    zero_init_residual: bool = False,
    width_per_group: int = 64,
    layer4_stride: int = 2,
    **kwargs,
) -> Tuple[nn.Module, int]:
    # Default to BatchNorm when no norm layer is specified.
    if norm is None:
        norm = RESNET_NORM_LAYER.BatchNorm

    cfg = VISSLAdapter.get_model_config_template()
    cfg.TRUNK = AttrDict(
        {
            "NAME": "resnet",
            "RESNETS": AttrDict(
                {
                    "DEPTH": depth,
                    "WIDTH_MULTIPLIER": width_multiplier,
                    "NORM": norm,
                    "GROUPNORM_GROUPS": groupnorm_groups,
                    "STANDARDIZE_CONVOLUTIONS": standardize_convolutions,
                    "GROUPS": groups,
                    "ZERO_INIT_RESIDUAL": zero_init_residual,
                    "WIDTH_PER_GROUP": width_per_group,
                    "LAYER4_STRIDE": layer4_stride,
                }
            ),
        }
    )

    trunk = MODEL_TRUNKS_REGISTRY["resnet"](cfg, model_name="resnet")
    trunk.model_config = cfg

    # ResNet-50-style trunks expose a 2048-dimensional pooled feature.
    return trunk, 2048


def register_vissl_backbones(register: FlashRegistry):
    # Registration is a no-op when VISSL is not installed.
    if _VISSL_AVAILABLE:
        for backbone in (vision_transformer, resnet):
            register(backbone)
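Each factory above returns the trunk together with its output feature dimension. A brief usage sketch (not part of the diff; assumes VISSL is installed):

    # Hypothetical usage sketch: build trunks directly from the factories.
    from flash.image.embedding.backbones.vissl_backbones import resnet, vision_transformer

    trunk, num_features = resnet(depth=50)  # num_features == 2048
    vit_trunk, vit_features = vision_transformer(image_size=224, patch_size=16)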
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry  # noqa: F401
from flash.image.embedding.heads.vissl_heads import register_vissl_heads  # noqa: F401

IMAGE_EMBEDDER_HEADS = FlashRegistry("embedder_heads")
register_vissl_heads(IMAGE_EMBEDDER_HEADS)
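The head factories this file registers (defined in the next file) are then retrievable by name. A hypothetical sketch, assuming VISSL is installed:

    # Hypothetical usage sketch: build the SwAV projection head via the registry.
    from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS

    head = IMAGE_EMBEDDER_HEADS.get("swav_head")(dims=[2048, 2048, 128])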
@@ -0,0 +1,147 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import torch
import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE

if _VISSL_AVAILABLE:
    from vissl.config.attr_dict import AttrDict
    from vissl.models.heads import MODEL_HEADS_REGISTRY, register_model_head

    from flash.image.embedding.vissl.adapter import VISSLAdapter
else:
    AttrDict = object


class SimCLRHead(nn.Module):
    def __init__(
        self,
        model_config: AttrDict,
        dims: List[int] = [2048, 2048, 128],
        use_bn: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.dims = dims
        self.use_bn = use_bn

        self.clf = self.create_mlp()

    def create_mlp(self):
        # Build an MLP projection head: Linear (+ optional BatchNorm) + ReLU
        # for each hidden layer, then a final Linear to the output dimension.
        layers = []
        last_dim = self.dims[0]

        for dim in self.dims[1:-1]:
            layers.append(nn.Linear(last_dim, dim))

            if self.use_bn:
                layers.append(
                    nn.BatchNorm1d(
                        dim,
                        eps=self.model_config.HEAD.BATCHNORM_EPS,
                        momentum=self.model_config.HEAD.BATCHNORM_MOMENTUM,
                    )
                )

            layers.append(nn.ReLU(inplace=True))
            last_dim = dim

        layers.append(nn.Linear(last_dim, self.dims[-1]))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.clf(x)


if _VISSL_AVAILABLE:
    SimCLRHead = register_model_head("simclr_head")(SimCLRHead)


def simclr_head(
    dims: List[int] = [2048, 2048, 128],
    use_bn: bool = True,
    **kwargs,
) -> nn.Module:
    cfg = VISSLAdapter.get_model_config_template()
    head_kwargs = {
        "dims": dims,
        "use_bn": use_bn,
    }

    cfg.HEAD.PARAMS.append(["simclr_head", head_kwargs])

    head = MODEL_HEADS_REGISTRY["simclr_head"](cfg, **head_kwargs)
    head.model_config = cfg

    return head


def swav_head(
    dims: List[int] = [2048, 2048, 128],
    use_bn: bool = True,
    num_clusters: Union[int, List[int]] = [3000],
    use_bias: bool = True,
    return_embeddings: bool = True,
    skip_last_bn: bool = True,
    normalize_feats: bool = True,
    activation_name: str = "ReLU",
    use_weight_norm_prototypes: bool = False,
    **kwargs,
) -> nn.Module:
    cfg = VISSLAdapter.get_model_config_template()
    head_kwargs = {
        "dims": dims,
        "use_bn": use_bn,
        # VISSL expects a list of cluster counts, so promote a bare int.
        "num_clusters": [num_clusters] if isinstance(num_clusters, int) else num_clusters,
        "use_bias": use_bias,
        "return_embeddings": return_embeddings,
        "skip_last_bn": skip_last_bn,
        "normalize_feats": normalize_feats,
        "activation_name": activation_name,
        "use_weight_norm_prototypes": use_weight_norm_prototypes,
    }

    cfg.HEAD.PARAMS.append(["swav_head", head_kwargs])

    head = MODEL_HEADS_REGISTRY["swav_head"](cfg, **head_kwargs)
    head.model_config = cfg

    return head


def barlow_twins_head(**kwargs) -> nn.Module:
    # Barlow Twins reuses the SimCLR MLP head with wider projection layers.
    return simclr_head(dims=[2048, 8192, 8192, 8192], **kwargs)


def dino_head(**kwargs) -> nn.Module:
    # DINO is configured as a SwAV-style head with GELU activations and
    # weight-normalized prototypes.
    return swav_head(
        dims=[384, 2048, 2048, 256],
        use_bn=False,
        return_embeddings=False,
        activation_name="GELU",
        num_clusters=[65536],
        use_weight_norm_prototypes=True,
        **kwargs,
    )


def register_vissl_heads(register: FlashRegistry):
    for ssl_head in (swav_head, simclr_head, dino_head, barlow_twins_head):
        register(ssl_head)
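Because SimCLRHead.forward is just the MLP, the head can be exercised standalone on pooled backbone features. A hypothetical sketch (assumes VISSL is installed and that the adapter's config template provides the BATCHNORM_EPS/BATCHNORM_MOMENTUM fields used above):

    # Hypothetical usage sketch: project a batch of 2048-d features to 128-d.
    import torch

    from flash.image.embedding.heads.vissl_heads import simclr_head

    head = simclr_head(dims=[2048, 2048, 128])
    projections = head(torch.randn(4, 2048))  # expected shape: (4, 128)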
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry  # noqa: F401
from flash.image.embedding.losses.vissl_losses import register_vissl_losses  # noqa: F401

IMAGE_EMBEDDER_LOSS_FUNCTIONS = FlashRegistry("embedder_losses")
register_vissl_losses(IMAGE_EMBEDDER_LOSS_FUNCTIONS)
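As with the backbone and head registries, the loss registry can be inspected by name. A hypothetical sketch:

    from flash.image.embedding.losses import IMAGE_EMBEDDER_LOSS_FUNCTIONS

    # Names depend on what register_vissl_losses adds when VISSL is installed.
    print(IMAGE_EMBEDDER_LOSS_FUNCTIONS.available_keys())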