This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

VISSL initial integration #682

Merged 68 commits on Sep 21, 2021

Commits (68)
6f907e2  tests (ananyahjha93, Sep 10, 2021)
b1fab6e  merge (ananyahjha93, Sep 10, 2021)
54c2efe  . (ananyahjha93, Sep 10, 2021)
244e7a5  . (ananyahjha93, Sep 10, 2021)
2bce93e  hooks cleanup (ananyahjha93, Sep 10, 2021)
603b421  . (ananyahjha93, Sep 10, 2021)
8307e84  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 10, 2021)
5153af9  multi-gpu (ananyahjha93, Sep 10, 2021)
5061d6d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 10, 2021)
95cad71  strategies (ananyahjha93, Sep 11, 2021)
4354a78  Merge branch 'feature/vissl' of https://github.com/PyTorchLightning/l… (ananyahjha93, Sep 11, 2021)
a120ff8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 11, 2021)
faccde3  . (ananyahjha93, Sep 12, 2021)
8c18e47  merge (ananyahjha93, Sep 12, 2021)
a4c80c0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 12, 2021)
d5843eb  Merge branch 'master' into feature/vissl (ananyahjha93, Sep 12, 2021)
37ca68b  Updates (ethanwharris, Sep 12, 2021)
a9870e7  . (ananyahjha93, Sep 12, 2021)
bc0af99  Merge branch 'feature/vissl' of https://github.com/PyTorchLightning/l… (ananyahjha93, Sep 12, 2021)
3bd3e7d  . (ananyahjha93, Sep 13, 2021)
2f1f07c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 13, 2021)
c567009  gtg, docstrings, cpu testing (ananyahjha93, Sep 13, 2021)
4d85ed4  merge (ananyahjha93, Sep 13, 2021)
964b97c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 13, 2021)
c3b3863  . (ananyahjha93, Sep 13, 2021)
45dae27  Merge branch 'feature/vissl' of https://github.com/PyTorchLightning/l… (ananyahjha93, Sep 13, 2021)
f3fbaf6  test, exmaple (ananyahjha93, Sep 13, 2021)
7bce90c  . (ananyahjha93, Sep 14, 2021)
4583bcd  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 14, 2021)
779859e  transforms (ananyahjha93, Sep 14, 2021)
41eab70  conflict (ananyahjha93, Sep 14, 2021)
cfb36a4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 14, 2021)
59617fc  Merge branch 'master' into feature/vissl (ananyahjha93, Sep 14, 2021)
cd15de8  . (ananyahjha93, Sep 15, 2021)
961c508  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 15, 2021)
b46ae7d  tests (ananyahjha93, Sep 10, 2021)
7c5feb9  merge (ananyahjha93, Sep 10, 2021)
4901e60  . (ananyahjha93, Sep 10, 2021)
26b9e5b  . (ananyahjha93, Sep 10, 2021)
964f100  hooks cleanup (ananyahjha93, Sep 10, 2021)
53d39ec  . (ananyahjha93, Sep 10, 2021)
349eac0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 10, 2021)
aa52179  multi-gpu (ananyahjha93, Sep 10, 2021)
ddd5b5b  strategies (ananyahjha93, Sep 11, 2021)
99438d5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 10, 2021)
b82557c  . (ananyahjha93, Sep 12, 2021)
2855c4e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 11, 2021)
760316e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 12, 2021)
92a5c0e  . (ananyahjha93, Sep 12, 2021)
f1dd5cc  Updates (ethanwharris, Sep 12, 2021)
1849c82  . (ananyahjha93, Sep 13, 2021)
8db9270  gtg, docstrings, cpu testing (ananyahjha93, Sep 13, 2021)
51869bc  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 13, 2021)
785453d  . (ananyahjha93, Sep 13, 2021)
7550ac2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 13, 2021)
77c391f  tests (ananyahjha93, Sep 21, 2021)
e4e0543  merge (ananyahjha93, Sep 21, 2021)
01cf3e4  imports (ananyahjha93, Sep 21, 2021)
7695812  Fix some import issues (ethanwharris, Sep 21, 2021)
cb3f849  Add classy vision master install (ethanwharris, Sep 21, 2021)
db276a8  Drop JIT test (ethanwharris, Sep 21, 2021)
be97ef2  pull (ananyahjha93, Sep 21, 2021)
c3973d9  Small fixes (ethanwharris, Sep 21, 2021)
da93dbb  Style (ethanwharris, Sep 21, 2021)
5fc9f8e  Fix (ethanwharris, Sep 21, 2021)
166006c  Updates (ethanwharris, Sep 21, 2021)
81cb161  tests (ananyahjha93, Sep 21, 2021)
4f68899  changelog (ananyahjha93, Sep 21, 2021)
Changes from commit b82557cfe094078e5d11ba57464b40f24c2066bf ("."), committed by ananyahjha93 on Sep 21, 2021
66 changes: 53 additions & 13 deletions flash/image/embedding/backbones/vissl_backbones.py
@@ -19,6 +19,7 @@
 if _VISSL_AVAILABLE:
     from vissl.config.attr_dict import AttrDict
     from vissl.models.trunks import MODEL_TRUNKS_REGISTRY
+    from vissl.models.model_helpers import RESNET_NORM_LAYER

     from flash.image.embedding.vissl.adapter import VISSLAdapter

@@ -45,18 +46,18 @@ def vision_transformer(
         "NAME": "vision_transformer",
         "VISION_TRANSFORMERS": AttrDict(
             {
-                "image_size": image_size,
-                "patch_size": patch_size,
-                "hidden_dim": hidden_dim,
-                "num_layers": num_layers,
-                "num_heads": num_heads,
-                "mlp_dim": mlp_dim,
-                "dropout_rate": dropout_rate,
-                "attention_dropout_rate": attention_dropout_rate,
-                "drop_path_rate": drop_path_rate,
-                "qkv_bias": qkv_bias,
-                "qk_scale": qk_scale,
-                "classifier": classifier,
+                "IMAGE_SIZE": image_size,
+                "PATCH_SIZE": patch_size,
+                "HIDDEN_DIM": hidden_dim,
+                "NUM_LAYERS": num_layers,
+                "NUM_HEADS": num_heads,
+                "MLP_DIM": mlp_dim,
+                "DROPOUT_RATE": dropout_rate,
+                "ATTENTION_DROPOUT_RATE": attention_dropout_rate,
+                "DROP_PATH_RATE": drop_path_rate,
+                "QKV_BIAS": qkv_bias,
+                "QK_SCALE": qk_scale,
+                "CLASSIFIER": classifier,
             }
         ),
     }
@@ -68,5 +69,44 @@ def vision_transformer(
     return trunk, trunk.num_features


+def resnet(
+    depth: int = 50,
+    width_multiplier: int = 1,
+    norm: RESNET_NORM_LAYER = RESNET_NORM_LAYER.BatchNorm,
+    groupnorm_groups: int = 32,
+    standardize_convolutions: bool = False,
+    groups: int = 1,
+    zero_init_residual: bool = False,
+    width_per_group: int = 64,
+    layer4_stride: int = 2,
+    **kwargs,
+) -> nn.Module:
+    cfg = VISSLAdapter.get_model_config_template()
+    cfg.TRUNK = AttrDict(
+        {
+            "NAME": "resnet",
+            "RESNETS": AttrDict(
+                {
+                    'DEPTH': depth,
+                    'WIDTH_MULTIPLIER': width_multiplier,
+                    'NORM': norm,
+                    'GROUPNORM_GROUPS': groupnorm_groups,
+                    'STANDARDIZE_CONVOLUTIONS': standardize_convolutions,
+                    'GROUPS': groups,
+                    'ZERO_INIT_RESIDUAL': zero_init_residual,
+                    'WIDTH_PER_GROUP': width_per_group,
+                    'LAYER4_STRIDE': layer4_stride,
+                }
+            ),
+        }
+    )
+
+    trunk = MODEL_TRUNKS_REGISTRY["resnet"](cfg, model_name="resnet")
+    trunk.model_config = cfg
+
+    return trunk, 2048
+
+
 def register_vissl_backbones(register: FlashRegistry):
-    register(vision_transformer)
+    for backbone in (vision_transformer, resnet):
+        register(backbone)
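
For reference, a minimal usage sketch of the new backbone builder. The local registry instance below is hypothetical; the (trunk, num_features) return contract and the resnet defaults come from the diff above:

    from flash.core.registry import FlashRegistry
    from flash.image.embedding.backbones.vissl_backbones import register_vissl_backbones

    backbones = FlashRegistry("backbones")  # hypothetical registry instance for this sketch
    register_vissl_backbones(backbones)

    # each builder returns (trunk, num_features); the resnet builder reports 2048 features
    trunk, num_features = backbones.get("resnet")(depth=50, width_multiplier=1)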
84 changes: 79 additions & 5 deletions flash/image/embedding/heads/vissl_heads.py
@@ -12,22 +12,86 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import List, Union
+from functools import partial

 import torch
 import torch.nn as nn

 from flash.core.registry import FlashRegistry
 from flash.core.utilities.imports import _VISSL_AVAILABLE

 if _VISSL_AVAILABLE:
-    from vissl.models.heads import MODEL_HEADS_REGISTRY
+    from vissl.models.heads import MODEL_HEADS_REGISTRY, register_model_head
+    from vissl.config.attr_dict import AttrDict

     from flash.image.embedding.vissl.adapter import VISSLAdapter


+@register_model_head("simclr_head")
+class SimCLRHead(nn.Module):
+    def __init__(
+        self,
+        model_config: AttrDict,
+        dims: List[int] = [2048, 2048, 128],
+        use_bn: bool = True,
+        **kwargs,
+    ) -> nn.Module:
+        super().__init__()
+
+        self.model_config = model_config
+        self.dims = dims
+        self.use_bn = use_bn
+
+        self.clf = self.create_mlp()
+
+    def create_mlp(self):
+        layers = []
+        last_dim = self.dims[0]
+
+        for dim in self.dims[1:-1]:
+            layers.append(nn.Linear(last_dim, dim))
+
+            if self.use_bn:
+                layers.append(
+                    nn.BatchNorm1d(
+                        dim,
+                        eps=self.model_config.HEAD.BATCHNORM_EPS,
+                        momentum=self.model_config.HEAD.BATCHNORM_MOMENTUM,
+                    )
+                )
+
+            layers.append(nn.ReLU(inplace=True))
+            last_dim = dim  # track the input width for the next layer
+
+        layers.append(nn.Linear(last_dim, self.dims[-1]))
+        return nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.clf(x)
+
+
+def simclr_head(
+    dims: List[int] = [2048, 2048, 128],
+    use_bn: bool = True,
+    **kwargs,
+) -> nn.Module:
+    cfg = VISSLAdapter.get_model_config_template()
+    head_kwargs = {
+        "dims": dims,
+        "use_bn": use_bn,
+    }
+
+    cfg.HEAD.PARAMS.append(["simclr_head", head_kwargs])
+
+    head = MODEL_HEADS_REGISTRY["simclr_head"](cfg, **head_kwargs)
+    head.model_config = cfg
+
+    return head
+
+
 def swav_head(
-    dims: List[int] = [384, 2048, 2048, 256],
-    use_bn: bool = False,
-    num_clusters: Union[int, List[int]] = [65536],
+    dims: List[int] = [2048, 2048, 128],
+    use_bn: bool = True,
+    num_clusters: Union[int, List[int]] = [3000],
     use_bias: bool = True,
     return_embeddings: bool = False,
     skip_last_bn: bool = True,
@@ -57,5 +121,15 @@ def swav_head(
     return head


+barlow_twins_head = partial(simclr_head, dims=[2048, 8192, 8192, 8192])
+dino_head = partial(
+    swav_head,
+    dims=[384, 2048, 2048, 256],
+    use_bn=False,
+    num_clusters=[65536],
+)
+
+
 def register_vissl_heads(register: FlashRegistry):
-    register(swav_head)
+    for ssl_head in (swav_head, simclr_head, dino_head, barlow_twins_head):
+        register(ssl_head)
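
Since barlow_twins_head and dino_head are plain functools.partial wrappers, the calls below are equivalent ways to build the same heads (a sketch, assuming VISSL is installed):

    # Barlow Twins projector: the partial just widens the MLP dims
    head = barlow_twins_head()
    head = simclr_head(dims=[2048, 8192, 8192, 8192])

    # DINO head: swav_head pinned to 65536 prototypes with BN disabled
    head = dino_head()
    head = swav_head(dims=[384, 2048, 2048, 256], use_bn=False, num_clusters=[65536])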
2 changes: 2 additions & 0 deletions flash/image/embedding/losses/vissl_losses.py
@@ -141,13 +141,15 @@ def moco_loss(
     queue_size: int = 65536,
     momentum: float = 0.999,
     temperature: int = 0.2,
+    shuffle_batch: bool = True,
 ):
     cfg = AttrDict(
         {
             "embedding_dim": embedding_dim,
             "queue_size": queue_size,
             "momentum": momentum,
             "temperature": temperature,
             "shuffle_batch": shuffle_batch,
         }
     )

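The new flag lands in the loss config that the MoCo hook reads back, so batch shuffling can be turned off where it has no effect (e.g. single-device runs). A sketch, with embedding_dim=128 assumed here to match the simclr_head output used by the moco strategy:

    from flash.image.embedding.losses.vissl_losses import moco_loss

    loss_fn = moco_loss(embedding_dim=128, shuffle_batch=False)
    # the moco strategy below forwards loss_fn.loss_config.shuffle_batch into MoCoHook
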
3 changes: 1 addition & 2 deletions flash/image/embedding/model.py
@@ -80,8 +80,7 @@ def __init__(
         # assert embedding_dim == num_features

         metadata = self.training_strategy_registry.get(training_strategy, with_metadata=True)
-        loss_fn, head = metadata["fn"](**kwargs)
-        hooks = metadata["metadata"]["hooks"]
+        loss_fn, head, hooks = metadata["fn"](**kwargs)

         adapter = metadata["metadata"]["adapter"].from_task(
             self,
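
Hooks now travel with the strategy function's return value rather than with the registry metadata. A sketch of the updated lookup as it would run inside ImageEmbedder.__init__ (names taken from the diff):

    metadata = self.training_strategy_registry.get("moco", with_metadata=True)
    loss_fn, head, hooks = metadata["fn"]()        # hooks come from the strategy itself now
    adapter_cls = metadata["metadata"]["adapter"]  # still resolved from metadata (VISSLAdapter)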
30 changes: 9 additions & 21 deletions flash/image/embedding/strategies/vissl_strategies.py
@@ -17,61 +17,49 @@

 if _VISSL_AVAILABLE:
     from vissl.hooks.dino_hooks import DINOHook
-    from vissl.hooks.moco_hooks import
-    from vissl.hooks.swav
+    from vissl.hooks.moco_hooks import MoCoHook
+    from vissl.hooks.swav_hooks import SwAVUpdateQueueScoresHook, NormalizePrototypesHook

     from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS
     from flash.image.embedding.losses import IMAGE_EMBEDDER_LOSS_FUNCTIONS
     from flash.image.embedding.vissl.adapter import VISSLAdapter


-HOOKS_DICT = {
-    "dino": [DINOHook()],
-    "moco": [],
-    "swav": [],
-}
-
-
-def dino(head: str = "swav_head", **kwargs):
+def dino(head: str = "dino_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("dino_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)

-    return loss_fn, head
+    return loss_fn, head, [DINOHook()]


 def swav(head: str = "swav_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("swav_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)

-    return loss_fn, head
+    return loss_fn, head, [SwAVUpdateQueueScoresHook(), NormalizePrototypesHook()]


 def simclr(head: str = "simclr_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("simclr_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)

-    return loss_fn, head
+    return loss_fn, head, []


 def moco(head: str = "simclr_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("moco_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)

-    return loss_fn, head
+    return loss_fn, head, [MoCoHook(loss_fn.loss_config.momentum, loss_fn.loss_config.shuffle_batch)]


 def barlow_twins(head: str = "barlow_twins_head", **kwargs):
     loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("barlow_twins_loss")(**kwargs)
     head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)

-    return loss_fn, head
+    return loss_fn, head, []


 def register_vissl_strategies(register: FlashRegistry):
     for training_strategy in (dino, swav, simclr, moco, barlow_twins):
-        register(
-            training_strategy,
-            hooks=HOOKS_DICT[training_strategy.__name__],
-            adapter=VISSLAdapter,
-            providers=_VISSL
-        )
+        register(training_strategy, name=training_strategy.__name__, adapter=VISSLAdapter, providers=_VISSL)
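
Each strategy is now self-contained: it returns the hooks its scheme needs (possibly none) alongside the loss and head, instead of relying on the old HOOKS_DICT. A sketch, assuming VISSL is installed:

    loss_fn, head, hooks = moco()    # hooks == [MoCoHook(0.999, True)], built from the loss config
    loss_fn, head, hooks = simclr()  # hooks == [], SimCLR needs no per-iteration hooks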
6 changes: 6 additions & 0 deletions flash/image/embedding/vissl/adapter.py
@@ -170,6 +170,12 @@ def get_model_config_template():
             }
         ),
         "_MODEL_INIT_SEED": 0,
+        "ACTIVATION_CHECKPOINTING": AttrDict(
+            {
+                "USE_ACTIVATION_CHECKPOINTING": False,
+                "NUM_ACTIVATION_CHECKPOINTING_SPLITS": 2,
+            }
+        ),
     }
 )
)