Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[feat] Add Dino (#259)
Browse files Browse the repository at this point in the history
* update

* update
  • Loading branch information
tchaton authored May 2, 2021
1 parent 8007c26 commit 8bf5d76
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 6 deletions.
30 changes: 30 additions & 0 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from functools import partial
from typing import Tuple

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from torch import nn as nn
Expand Down Expand Up @@ -180,3 +181,32 @@ def _fn_timm(
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm"
)


# Paper: Emerging Properties in Self-Supervised Vision Transformers
# https://arxiv.org/abs/2104.14294 from Mathilde Caron and al. (29 Apr 2021)
# weights from https://github.com/facebookresearch/dino
def dino_deits16(*_, **__):
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_deits16')
return backbone, 384


def dino_deits8(*_, **__):
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_deits8')
return backbone, 384


def dino_vitb16(*_, **__):
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
return backbone, 768


def dino_vitb8(*_, **__):
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
return backbone, 768


IMAGE_CLASSIFIER_BACKBONES(dino_deits16)
IMAGE_CLASSIFIER_BACKBONES(dino_deits8)
IMAGE_CLASSIFIER_BACKBONES(dino_vitb16)
IMAGE_CLASSIFIER_BACKBONES(dino_vitb8)
8 changes: 3 additions & 5 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,10 @@ def __init__(
self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

head = head(num_features, num_classes) if isinstance(head, FunctionType) else head
self.head = head or nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(num_features, num_classes),
)
self.head = head or nn.Sequential(nn.Linear(num_features, num_classes), )

def forward(self, x) -> torch.Tensor:
x = self.backbone(x)
if x.dim() == 4:
x = x.mean(-1).mean(-1)
return self.head(x)
2 changes: 1 addition & 1 deletion flash_examples/finetuning/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def fn_resnet(pretrained: bool = True):
print(ImageClassifier.available_backbones())

# 4. Build the model
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
model = ImageClassifier(backbone="dino_vitb16", num_classes=datamodule.num_classes)

# 5. Create the trainer.
trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)
Expand Down

0 comments on commit 8bf5d76

Please sign in to comment.