This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit d190c76: VISSL initial integration (#682)
ananyahjha93 and ethanwharris authored Sep 21, 2021 (1 parent: dfe8854)
Co-authored-by: Ethan Harris <[email protected]>
Showing 25 changed files with 1,098 additions and 141 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci-testing.yml
@@ -136,7 +136,8 @@ jobs:
      - name: Install vissl
        if: matrix.topic[1] == 'image_extras'
        run: |
-         pip install git+https://github.com/facebookresearch/vissl.git@master
+         pip install git+https://github.com/facebookresearch/ClassyVision.git
+         pip install git+https://github.com/facebookresearch/vissl.git
      - name: Install graph test dependencies
        if: matrix.topic[0] == 'graph'
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support `learn2learn` training_strategy for `ImageClassifier` ([#737](https://github.com/PyTorchLightning/lightning-flash/pull/737))
 
+- Added `vissl` training_strategies for `ImageEmbedder` ([#682](https://github.com/PyTorchLightning/lightning-flash/pull/682))
+
 ### Changed
 
 - Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))
6 changes: 5 additions & 1 deletion flash/core/registry.py
@@ -111,7 +111,11 @@ def _register_function(
         if not callable(fn):
             raise MisconfigurationException(f"You can only register a callable, found: {fn}")
 
-        name = name or fn.__name__
+        if name is None:
+            if hasattr(fn, "func"):
+                name = fn.func.__name__
+            else:
+                name = fn.__name__
 
         if self._verbose:
             rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}")
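The new branch makes the registry accept callables wrapped in functools.partial, which have no __name__ of their own but expose the wrapped callable via .func. A minimal sketch of the case this fixes (the resnet18 partial here is hypothetical):

from functools import partial

def resnet(depth: int = 50):
    ...

resnet18 = partial(resnet, depth=18)
assert not hasattr(resnet18, "__name__")   # partial objects carry no __name__
assert resnet18.func.__name__ == "resnet"  # but .func reaches the wrapped callable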
1 change: 1 addition & 0 deletions flash/core/utilities/providers.py
@@ -45,3 +45,4 @@ def __str__(self):
 _FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq")
 _OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML")
 _PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo")
+_VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl")
5 changes: 5 additions & 0 deletions flash/image/embedding/backbones/__init__.py
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry # noqa: F401
from flash.image.embedding.backbones.vissl_backbones import register_vissl_backbones # noqa: F401

IMAGE_EMBEDDER_BACKBONES = FlashRegistry("embedder_backbones")
register_vissl_backbones(IMAGE_EMBEDDER_BACKBONES)
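As a usage sketch (assuming vissl is installed, so register_vissl_backbones actually populates the registry), the registered builders can then be looked up by name:

from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES

print(IMAGE_EMBEDDER_BACKBONES.available_keys())  # e.g. ['resnet', 'vision_transformer']
resnet_builder = IMAGE_EMBEDDER_BACKBONES.get("resnet")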
117 changes: 117 additions & 0 deletions flash/image/embedding/backbones/vissl_backbones.py
@@ -0,0 +1,117 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE

if _VISSL_AVAILABLE:
    from vissl.config.attr_dict import AttrDict
    from vissl.models.model_helpers import RESNET_NORM_LAYER
    from vissl.models.trunks import MODEL_TRUNKS_REGISTRY

    from flash.image.embedding.vissl.adapter import VISSLAdapter
else:
    RESNET_NORM_LAYER = object


def vision_transformer(
    image_size: int = 224,
    patch_size: int = 16,
    hidden_dim: int = 384,
    num_layers: int = 12,
    num_heads: int = 6,
    mlp_dim: int = 1532,
    dropout_rate: float = 0,
    attention_dropout_rate: float = 0,
    drop_path_rate: float = 0,
    qkv_bias: bool = True,
    qk_scale: bool = False,
    classifier: str = "token",
    **kwargs,
) -> nn.Module:

    cfg = VISSLAdapter.get_model_config_template()
    cfg.TRUNK = AttrDict(
        {
            "NAME": "vision_transformer",
            "VISION_TRANSFORMERS": AttrDict(
                {
                    "IMAGE_SIZE": image_size,
                    "PATCH_SIZE": patch_size,
                    "HIDDEN_DIM": hidden_dim,
                    "NUM_LAYERS": num_layers,
                    "NUM_HEADS": num_heads,
                    "MLP_DIM": mlp_dim,
                    "DROPOUT_RATE": dropout_rate,
                    "ATTENTION_DROPOUT_RATE": attention_dropout_rate,
                    "DROP_PATH_RATE": drop_path_rate,
                    "QKV_BIAS": qkv_bias,
                    "QK_SCALE": qk_scale,
                    "CLASSIFIER": classifier,
                }
            ),
        }
    )

    trunk = MODEL_TRUNKS_REGISTRY["vision_transformer"](cfg, model_name="vision_transformer")
    trunk.model_config = cfg

    return trunk, trunk.num_features


def resnet(
    depth: int = 50,
    width_multiplier: int = 1,
    norm: RESNET_NORM_LAYER = None,
    groupnorm_groups: int = 32,
    standardize_convolutions: bool = False,
    groups: int = 1,
    zero_init_residual: bool = False,
    width_per_group: int = 64,
    layer4_stride: int = 2,
    **kwargs,
) -> nn.Module:
    if norm is None:
        norm = RESNET_NORM_LAYER.BatchNorm
    cfg = VISSLAdapter.get_model_config_template()
    cfg.TRUNK = AttrDict(
        {
            "NAME": "resnet",
            "RESNETS": AttrDict(
                {
                    "DEPTH": depth,
                    "WIDTH_MULTIPLIER": width_multiplier,
                    "NORM": norm,
                    "GROUPNORM_GROUPS": groupnorm_groups,
                    "STANDARDIZE_CONVOLUTIONS": standardize_convolutions,
                    "GROUPS": groups,
                    "ZERO_INIT_RESIDUAL": zero_init_residual,
                    "WIDTH_PER_GROUP": width_per_group,
                    "LAYER4_STRIDE": layer4_stride,
                }
            ),
        }
    )

    trunk = MODEL_TRUNKS_REGISTRY["resnet"](cfg, model_name="resnet")
    trunk.model_config = cfg

    return trunk, 2048


def register_vissl_backbones(register: FlashRegistry):
    if _VISSL_AVAILABLE:
        for backbone in (vision_transformer, resnet):
            register(backbone)
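Note that both builders return a (trunk, num_features) pair rather than the bare nn.Module their annotations suggest. A minimal sketch of calling one directly, assuming vissl is installed:

from flash.image.embedding.backbones.vissl_backbones import resnet, vision_transformer

trunk, num_features = resnet(depth=50)  # VISSL ResNet-50 trunk; num_features == 2048
vit_trunk, vit_features = vision_transformer(image_size=224, patch_size=16)  # ViT-S/16-style trunk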
5 changes: 5 additions & 0 deletions flash/image/embedding/heads/__init__.py
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry # noqa: F401
from flash.image.embedding.heads.vissl_heads import register_vissl_heads # noqa: F401

IMAGE_EMBEDDER_HEADS = FlashRegistry("embedder_heads")
register_vissl_heads(IMAGE_EMBEDDER_HEADS)
147 changes: 147 additions & 0 deletions flash/image/embedding/heads/vissl_heads.py
@@ -0,0 +1,147 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import torch
import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE

if _VISSL_AVAILABLE:
    from vissl.config.attr_dict import AttrDict
    from vissl.models.heads import MODEL_HEADS_REGISTRY, register_model_head

    from flash.image.embedding.vissl.adapter import VISSLAdapter
else:
    AttrDict = object


class SimCLRHead(nn.Module):
    def __init__(
        self,
        model_config: AttrDict,
        dims: List[int] = [2048, 2048, 128],
        use_bn: bool = True,
        **kwargs,
    ) -> nn.Module:
        super().__init__()

        self.model_config = model_config
        self.dims = dims
        self.use_bn = use_bn

        self.clf = self.create_mlp()

    def create_mlp(self):
        layers = []
        last_dim = self.dims[0]

        for dim in self.dims[1:-1]:
            layers.append(nn.Linear(last_dim, dim))

            if self.use_bn:
                layers.append(
                    nn.BatchNorm1d(
                        dim,
                        eps=self.model_config.HEAD.BATCHNORM_EPS,
                        momentum=self.model_config.HEAD.BATCHNORM_MOMENTUM,
                    )
                )

            layers.append(nn.ReLU(inplace=True))
            last_dim = dim

        layers.append(nn.Linear(last_dim, self.dims[-1]))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.clf(x)


if _VISSL_AVAILABLE:
    SimCLRHead = register_model_head("simclr_head")(SimCLRHead)


def simclr_head(
    dims: List[int] = [2048, 2048, 128],
    use_bn: bool = True,
    **kwargs,
) -> nn.Module:
    cfg = VISSLAdapter.get_model_config_template()
    head_kwargs = {
        "dims": dims,
        "use_bn": use_bn,
    }

    cfg.HEAD.PARAMS.append(["simclr_head", head_kwargs])

    head = MODEL_HEADS_REGISTRY["simclr_head"](cfg, **head_kwargs)
    head.model_config = cfg

    return head


def swav_head(
    dims: List[int] = [2048, 2048, 128],
    use_bn: bool = True,
    num_clusters: Union[int, List[int]] = [3000],
    use_bias: bool = True,
    return_embeddings: bool = True,
    skip_last_bn: bool = True,
    normalize_feats: bool = True,
    activation_name: str = "ReLU",
    use_weight_norm_prototypes: bool = False,
    **kwargs,
) -> nn.Module:
    cfg = VISSLAdapter.get_model_config_template()
    head_kwargs = {
        "dims": dims,
        "use_bn": use_bn,
        "num_clusters": [num_clusters] if isinstance(num_clusters, int) else num_clusters,
        "use_bias": use_bias,
        "return_embeddings": return_embeddings,
        "skip_last_bn": skip_last_bn,
        "normalize_feats": normalize_feats,
        "activation_name": activation_name,
        "use_weight_norm_prototypes": use_weight_norm_prototypes,
    }

    cfg.HEAD.PARAMS.append(["swav_head", head_kwargs])

    head = MODEL_HEADS_REGISTRY["swav_head"](cfg, **head_kwargs)
    head.model_config = cfg

    return head


def barlow_twins_head(**kwargs) -> nn.Module:
    return simclr_head(dims=[2048, 8192, 8192, 8192], **kwargs)


def dino_head(**kwargs) -> nn.Module:
    return swav_head(
        dims=[384, 2048, 2048, 256],
        use_bn=False,
        return_embeddings=False,
        activation_name="GELU",
        num_clusters=[65536],
        use_weight_norm_prototypes=True,
        **kwargs,
    )


def register_vissl_heads(register: FlashRegistry):
    for ssl_head in (swav_head, simclr_head, dino_head, barlow_twins_head):
        register(ssl_head)
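Each builder appends its kwargs to cfg.HEAD.PARAMS and instantiates the head from VISSL's MODEL_HEADS_REGISTRY, so barlow_twins_head and dino_head are thin presets over simclr_head and swav_head. A usage sketch, assuming vissl is installed:

from flash.image.embedding.heads.vissl_heads import dino_head, simclr_head

projector = simclr_head(dims=[2048, 2048, 128])  # Linear -> BatchNorm -> ReLU -> Linear projection MLP
dino = dino_head()  # SwAV-style head: 384 -> 2048 -> 2048 -> 256, 65536 weight-normed prototypes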
5 changes: 5 additions & 0 deletions flash/image/embedding/losses/__init__.py
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry # noqa: F401
from flash.image.embedding.losses.vissl_losses import register_vissl_losses # noqa: F401

IMAGE_EMBEDDER_LOSS_FUNCTIONS = FlashRegistry("embedder_losses")
register_vissl_losses(IMAGE_EMBEDDER_LOSS_FUNCTIONS)
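Taken together with the CHANGELOG entry above, these backbone, head, and loss registries presumably back the new training_strategy support in ImageEmbedder. A hypothetical sketch of how they compose; the strategy and transform names below are assumptions, not confirmed by the files shown here:

from flash.image import ImageEmbedder

embedder = ImageEmbedder(
    backbone="resnet",
    training_strategy="simclr",                # assumed strategy name
    head="simclr_head",
    pretraining_transform="simclr_transform",  # assumed transform name
)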
(Diff truncated: the remaining changed files were not loaded.)