This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

VISSL initial integration #682

Merged
merged 68 commits into master from feature/vissl on Sep 21, 2021

Commits (68)
6f907e2
tests
ananyahjha93 Sep 10, 2021
b1fab6e
merge
ananyahjha93 Sep 10, 2021
54c2efe
.
ananyahjha93 Sep 10, 2021
244e7a5
.
ananyahjha93 Sep 10, 2021
2bce93e
hooks cleanup
ananyahjha93 Sep 10, 2021
603b421
.
ananyahjha93 Sep 10, 2021
8307e84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2021
5153af9
multi-gpu
ananyahjha93 Sep 10, 2021
5061d6d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2021
95cad71
strategies
ananyahjha93 Sep 11, 2021
4354a78
Merge branch 'feature/vissl' of https://github.com/PyTorchLightning/l…
ananyahjha93 Sep 11, 2021
a120ff8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 11, 2021
faccde3
.
ananyahjha93 Sep 12, 2021
8c18e47
merge
ananyahjha93 Sep 12, 2021
a4c80c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2021
d5843eb
Merge branch 'master' into feature/vissl
ananyahjha93 Sep 12, 2021
37ca68b
Updates
ethanwharris Sep 12, 2021
a9870e7
.
ananyahjha93 Sep 12, 2021
bc0af99
Merge branch 'feature/vissl' of https://github.com/PyTorchLightning/l…
ananyahjha93 Sep 12, 2021
3bd3e7d
.
ananyahjha93 Sep 13, 2021
2f1f07c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2021
c567009
gtg, docstrings, cpu testing
ananyahjha93 Sep 13, 2021
4d85ed4
merge
ananyahjha93 Sep 13, 2021
964b97c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2021
c3b3863
.
ananyahjha93 Sep 13, 2021
45dae27
Merge branch 'feature/vissl' of https://github.com/PyTorchLightning/l…
ananyahjha93 Sep 13, 2021
f3fbaf6
test, exmaple
ananyahjha93 Sep 13, 2021
7bce90c
.
ananyahjha93 Sep 14, 2021
4583bcd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2021
779859e
transforms
ananyahjha93 Sep 14, 2021
41eab70
conflict
ananyahjha93 Sep 14, 2021
cfb36a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2021
59617fc
Merge branch 'master' into feature/vissl
ananyahjha93 Sep 14, 2021
cd15de8
.
ananyahjha93 Sep 15, 2021
961c508
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2021
b46ae7d
tests
ananyahjha93 Sep 10, 2021
7c5feb9
merge
ananyahjha93 Sep 10, 2021
4901e60
.
ananyahjha93 Sep 10, 2021
26b9e5b
.
ananyahjha93 Sep 10, 2021
964f100
hooks cleanup
ananyahjha93 Sep 10, 2021
53d39ec
.
ananyahjha93 Sep 10, 2021
349eac0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2021
aa52179
multi-gpu
ananyahjha93 Sep 10, 2021
ddd5b5b
strategies
ananyahjha93 Sep 11, 2021
99438d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2021
b82557c
.
ananyahjha93 Sep 12, 2021
2855c4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 11, 2021
760316e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2021
92a5c0e
.
ananyahjha93 Sep 12, 2021
f1dd5cc
Updates
ethanwharris Sep 12, 2021
1849c82
.
ananyahjha93 Sep 13, 2021
8db9270
gtg, docstrings, cpu testing
ananyahjha93 Sep 13, 2021
51869bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2021
785453d
.
ananyahjha93 Sep 13, 2021
7550ac2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2021
77c391f
tests
ananyahjha93 Sep 21, 2021
e4e0543
merge
ananyahjha93 Sep 21, 2021
01cf3e4
imports
ananyahjha93 Sep 21, 2021
7695812
Fix some import issues
ethanwharris Sep 21, 2021
cb3f849
Add classy vision master install
ethanwharris Sep 21, 2021
db276a8
Drop JIT test
ethanwharris Sep 21, 2021
be97ef2
pull
ananyahjha93 Sep 21, 2021
c3973d9
Small fixes
ethanwharris Sep 21, 2021
da93dbb
Style
ethanwharris Sep 21, 2021
5fc9f8e
Fix
ethanwharris Sep 21, 2021
166006c
Updates
ethanwharris Sep 21, 2021
81cb161
tests
ananyahjha93 Sep 21, 2021
4f68899
changelog
ananyahjha93 Sep 21, 2021
Files changed
3 changes: 2 additions & 1 deletion .github/workflows/ci-testing.yml
@@ -136,7 +136,8 @@ jobs:
       - name: Install vissl
         if: matrix.topic[1] == 'image_extras'
         run: |
-          pip install git+https://github.com/facebookresearch/vissl.git@master
+          pip install git+https://github.com/facebookresearch/ClassyVision.git
+          pip install git+https://github.com/facebookresearch/vissl.git

       - name: Install graph test dependencies
         if: matrix.topic[0] == 'graph'
6 changes: 5 additions & 1 deletion flash/core/registry.py
@@ -111,7 +111,11 @@ def _register_function(
         if not callable(fn):
             raise MisconfigurationException(f"You can only register a callable, found: {fn}")

-        name = name or fn.__name__
+        if name is None:
+            if hasattr(fn, "func"):
+                name = fn.func.__name__
+            else:
+                name = fn.__name__

         if self._verbose:
             rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}")
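Note on the change above: the two-step lookup lets `functools.partial` objects be registered — a `partial` has no `__name__` of its own but exposes the wrapped callable via `.func`. A minimal sketch of the behaviour (the `resnet` helper here is illustrative, not code from this PR):

from functools import partial

from flash.core.registry import FlashRegistry

backbones = FlashRegistry("backbones")


def resnet(depth: int = 50):
    return f"resnet-{depth}"


# A bare partial has no __name__, so the old `name or fn.__name__` lookup
# raised AttributeError; the new branch resolves the name from the wrapped
# callable, registering this entry under "resnet".
backbones(partial(resnet, depth=101))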
1 change: 1 addition & 0 deletions flash/core/utilities/providers.py
@@ -45,3 +45,4 @@ def __str__(self):
 _FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq")
 _OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML")
 _PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo")
+_VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl")
5 changes: 5 additions & 0 deletions flash/image/embedding/backbones/__init__.py
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry # noqa: F401
from flash.image.embedding.backbones.vissl_backbones import register_vissl_backbones # noqa: F401

IMAGE_EMBEDDER_BACKBONES = FlashRegistry("embedder_backbones")
register_vissl_backbones(IMAGE_EMBEDDER_BACKBONES)
117 changes: 117 additions & 0 deletions flash/image/embedding/backbones/vissl_backbones.py
@@ -0,0 +1,117 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE

if _VISSL_AVAILABLE:
    from vissl.config.attr_dict import AttrDict
    from vissl.models.model_helpers import RESNET_NORM_LAYER
    from vissl.models.trunks import MODEL_TRUNKS_REGISTRY

    from flash.image.embedding.vissl.adapter import VISSLAdapter
else:
    RESNET_NORM_LAYER = object


def vision_transformer(
    image_size: int = 224,
    patch_size: int = 16,
    hidden_dim: int = 384,
    num_layers: int = 12,
    num_heads: int = 6,
    mlp_dim: int = 1532,
    dropout_rate: float = 0,
    attention_dropout_rate: float = 0,
    drop_path_rate: float = 0,
    qkv_bias: bool = True,
    qk_scale: bool = False,
    classifier: str = "token",
    **kwargs,
) -> nn.Module:

[Review comment — Contributor: some docstrings would be good. Same below.]

    cfg = VISSLAdapter.get_model_config_template()
    cfg.TRUNK = AttrDict(
        {
            "NAME": "vision_transformer",
            "VISION_TRANSFORMERS": AttrDict(
                {
                    "IMAGE_SIZE": image_size,
                    "PATCH_SIZE": patch_size,
                    "HIDDEN_DIM": hidden_dim,
                    "NUM_LAYERS": num_layers,
                    "NUM_HEADS": num_heads,
                    "MLP_DIM": mlp_dim,
                    "DROPOUT_RATE": dropout_rate,
                    "ATTENTION_DROPOUT_RATE": attention_dropout_rate,
                    "DROP_PATH_RATE": drop_path_rate,
                    "QKV_BIAS": qkv_bias,
                    "QK_SCALE": qk_scale,
                    "CLASSIFIER": classifier,
                }
            ),
        }
    )

    trunk = MODEL_TRUNKS_REGISTRY["vision_transformer"](cfg, model_name="vision_transformer")
    trunk.model_config = cfg

    return trunk, trunk.num_features


def resnet(
    depth: int = 50,
    width_multiplier: int = 1,
    norm: RESNET_NORM_LAYER = None,
    groupnorm_groups: int = 32,
    standardize_convolutions: bool = False,
    groups: int = 1,
    zero_init_residual: bool = False,
    width_per_group: int = 64,
    layer4_stride: int = 2,
    **kwargs,
) -> nn.Module:
    if norm is None:
        norm = RESNET_NORM_LAYER.BatchNorm

    cfg = VISSLAdapter.get_model_config_template()
    cfg.TRUNK = AttrDict(
        {
            "NAME": "resnet",
            "RESNETS": AttrDict(
                {
                    "DEPTH": depth,
                    "WIDTH_MULTIPLIER": width_multiplier,
                    "NORM": norm,
                    "GROUPNORM_GROUPS": groupnorm_groups,
                    "STANDARDIZE_CONVOLUTIONS": standardize_convolutions,
                    "GROUPS": groups,
                    "ZERO_INIT_RESIDUAL": zero_init_residual,
                    "WIDTH_PER_GROUP": width_per_group,
                    "LAYER4_STRIDE": layer4_stride,
                }
            ),
        }
    )

    trunk = MODEL_TRUNKS_REGISTRY["resnet"](cfg, model_name="resnet")
    trunk.model_config = cfg

    return trunk, 2048


def register_vissl_backbones(register: FlashRegistry):
    if _VISSL_AVAILABLE:
        for backbone in (vision_transformer, resnet):
            register(backbone)
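To show how these builders are consumed downstream (a minimal sketch, assuming vissl is installed; `get` and `available_keys` are the standard `FlashRegistry` accessors):

from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES

# Registered by register_vissl_backbones above.
print(IMAGE_EMBEDDER_BACKBONES.available_keys())  # ['vision_transformer', 'resnet']

# Each builder returns a (trunk, num_features) pair.
backbone, num_features = IMAGE_EMBEDDER_BACKBONES.get("resnet")(depth=50)
print(num_features)  # 2048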
5 changes: 5 additions & 0 deletions flash/image/embedding/heads/__init__.py
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry # noqa: F401
from flash.image.embedding.heads.vissl_heads import register_vissl_heads # noqa: F401

IMAGE_EMBEDDER_HEADS = FlashRegistry("embedder_heads")
register_vissl_heads(IMAGE_EMBEDDER_HEADS)
147 changes: 147 additions & 0 deletions flash/image/embedding/heads/vissl_heads.py
@@ -0,0 +1,147 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import torch
import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE

if _VISSL_AVAILABLE:
    from vissl.config.attr_dict import AttrDict
    from vissl.models.heads import MODEL_HEADS_REGISTRY, register_model_head

    from flash.image.embedding.vissl.adapter import VISSLAdapter
else:
    AttrDict = object


class SimCLRHead(nn.Module):
    def __init__(
        self,
        model_config: AttrDict,
        dims: List[int] = [2048, 2048, 128],
        use_bn: bool = True,
        **kwargs,
    ) -> nn.Module:
        super().__init__()

        self.model_config = model_config
        self.dims = dims
        self.use_bn = use_bn

        self.clf = self.create_mlp()

    def create_mlp(self):
        layers = []
        last_dim = self.dims[0]

        for dim in self.dims[1:-1]:
            layers.append(nn.Linear(last_dim, dim))

            if self.use_bn:
                layers.append(
                    nn.BatchNorm1d(
                        dim,
                        eps=self.model_config.HEAD.BATCHNORM_EPS,
                        momentum=self.model_config.HEAD.BATCHNORM_MOMENTUM,
                    )
                )

            layers.append(nn.ReLU(inplace=True))
            last_dim = dim

        layers.append(nn.Linear(last_dim, self.dims[-1]))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.clf(x)


if _VISSL_AVAILABLE:
    SimCLRHead = register_model_head("simclr_head")(SimCLRHead)


def simclr_head(
    dims: List[int] = [2048, 2048, 128],
    use_bn: bool = True,
    **kwargs,
) -> nn.Module:
    cfg = VISSLAdapter.get_model_config_template()
    head_kwargs = {
        "dims": dims,
        "use_bn": use_bn,
    }

    cfg.HEAD.PARAMS.append(["simclr_head", head_kwargs])

    head = MODEL_HEADS_REGISTRY["simclr_head"](cfg, **head_kwargs)
    head.model_config = cfg

    return head


def swav_head(
    dims: List[int] = [2048, 2048, 128],
    use_bn: bool = True,
    num_clusters: Union[int, List[int]] = [3000],
    use_bias: bool = True,
    return_embeddings: bool = True,
    skip_last_bn: bool = True,
    normalize_feats: bool = True,
    activation_name: str = "ReLU",
    use_weight_norm_prototypes: bool = False,
    **kwargs,
) -> nn.Module:
    cfg = VISSLAdapter.get_model_config_template()
    head_kwargs = {
        "dims": dims,
        "use_bn": use_bn,
        "num_clusters": [num_clusters] if isinstance(num_clusters, int) else num_clusters,
        "use_bias": use_bias,
        "return_embeddings": return_embeddings,
        "skip_last_bn": skip_last_bn,
        "normalize_feats": normalize_feats,
        "activation_name": activation_name,
        "use_weight_norm_prototypes": use_weight_norm_prototypes,
    }

    cfg.HEAD.PARAMS.append(["swav_head", head_kwargs])

    head = MODEL_HEADS_REGISTRY["swav_head"](cfg, **head_kwargs)
    head.model_config = cfg

    return head


def barlow_twins_head(**kwargs) -> nn.Module:
    return simclr_head(dims=[2048, 8192, 8192, 8192], **kwargs)


def dino_head(**kwargs) -> nn.Module:
    return swav_head(
        dims=[384, 2048, 2048, 256],
        use_bn=False,
        return_embeddings=False,
        activation_name="GELU",
        num_clusters=[65536],
        use_weight_norm_prototypes=True,
        **kwargs,
    )


def register_vissl_heads(register: FlashRegistry):
    for ssl_head in (swav_head, simclr_head, dino_head, barlow_twins_head):
        register(ssl_head)
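As a quick sanity check of the head builders (again a sketch, assuming vissl is installed so that `VISSLAdapter.get_model_config_template()` resolves):

import torch

from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS

# simclr_head() builds a 2048 -> 2048 -> 128 projection MLP.
head = IMAGE_EMBEDDER_HEADS.get("simclr_head")()
print(head(torch.randn(4, 2048)).shape)  # torch.Size([4, 128])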
5 changes: 5 additions & 0 deletions flash/image/embedding/losses/__init__.py
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry # noqa: F401
from flash.image.embedding.losses.vissl_losses import register_vissl_losses # noqa: F401

IMAGE_EMBEDDER_LOSS_FUNCTIONS = FlashRegistry("embedder_losses")
register_vissl_losses(IMAGE_EMBEDDER_LOSS_FUNCTIONS)