diff --git a/flash/vision/backbones.py b/flash/vision/backbones.py index d7a8fb9906..b7c764320a 100644 --- a/flash/vision/backbones.py +++ b/flash/vision/backbones.py @@ -18,6 +18,7 @@ from functools import partial from typing import Tuple +import torch from pytorch_lightning import LightningModule from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn from torch import nn as nn @@ -180,3 +181,32 @@ def _fn_timm( IMAGE_CLASSIFIER_BACKBONES( fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm" ) + + +# Paper: Emerging Properties in Self-Supervised Vision Transformers +# https://arxiv.org/abs/2104.14294 from Mathilde Caron and al. (29 Apr 2021) +# weights from https://github.com/facebookresearch/dino +def dino_deits16(*_, **__): + backbone = torch.hub.load('facebookresearch/dino:main', 'dino_deits16') + return backbone, 384 + + +def dino_deits8(*_, **__): + backbone = torch.hub.load('facebookresearch/dino:main', 'dino_deits8') + return backbone, 384 + + +def dino_vitb16(*_, **__): + backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') + return backbone, 768 + + +def dino_vitb8(*_, **__): + backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8') + return backbone, 768 + + +IMAGE_CLASSIFIER_BACKBONES(dino_deits16) +IMAGE_CLASSIFIER_BACKBONES(dino_deits8) +IMAGE_CLASSIFIER_BACKBONES(dino_vitb16) +IMAGE_CLASSIFIER_BACKBONES(dino_vitb8) diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index f11a50eb01..2d68e07341 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -105,7 +105,6 @@ def __init__( head = head(num_features, num_classes) if isinstance(head, FunctionType) else head self.head = head or nn.Sequential( - nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(num_features, num_classes), ) diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 4a93ec1785..eb863f4e9f 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -48,7 +48,7 @@ def fn_resnet(pretrained: bool = True): print(ImageClassifier.available_backbones()) # 4. Build the model -model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) +model = ImageClassifier(backbone="dino_vitb16", num_classes=datamodule.num_classes) # 5. Create the trainer. trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)