diff --git a/CHANGELOG.md b/CHANGELOG.md index 117c68ebb0..f94f4bb30e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,10 +9,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method ([#552](https://github.com/PyTorchLightning/lightning-flash/pull/552)) + - Added support for `from_csv` and `from_data_frame` to `ImageClassificationData` ([#556](https://github.com/PyTorchLightning/lightning-flash/pull/556)) +- Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) + ### Changed +- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) + +- Removed bolts pretrained weights for SSL from ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) ### Deprecated diff --git a/docs/source/template/backbones.rst b/docs/source/template/backbones.rst index 82c629430f..c44860a670 100644 --- a/docs/source/template/backbones.rst +++ b/docs/source/template/backbones.rst @@ -34,11 +34,11 @@ Here's another example with a slightly more complex model: :language: python :pyobject: load_mlp_128_256 -Here's a more advanced example, which adds ``SimCLR`` to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/backbones.py `_: +Here's a another example, which adds ``DINO`` pretrained model from PyTorch Hub to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/backbones.py `_: .. literalinclude:: ../../../flash/image/backbones.py :language: python - :pyobject: load_simclr_imagenet + :pyobject: dino_vitb16 ------ diff --git a/flash/image/backbones.py b/flash/image/backbones.py index 103d3c37ee..9a54529a38 100644 --- a/flash/image/backbones.py +++ b/flash/image/backbones.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -import os import urllib.error -import warnings from functools import partial -from typing import Tuple +from typing import Tuple, Union import torch -from pytorch_lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn from torch import nn +from torch.hub import load_state_dict_from_url from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE if _TIMM_AVAILABLE: import timm @@ -33,21 +31,11 @@ import torchvision from torchvision.models.detection.backbone_utils import resnet_fpn_backbone -if _BOLTS_AVAILABLE: - if os.getenv("WARN_MISSING_PACKAGE") == "0": - with warnings.catch_warnings(record=True) as w: - from pl_bolts.models.self_supervised import SimCLR, SwAV - else: - from pl_bolts.models.self_supervised import SimCLR, SwAV - -ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com" - MOBILENET_MODELS = ["mobilenet_v2"] VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"] RESNET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"] DENSENET_MODELS = ["densenet121", "densenet169", "densenet161"] TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS -BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"] IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones") OBJ_DETECTION_BACKBONES = FlashRegistry("backbones") @@ -71,27 +59,18 @@ def wrapper(*args, pretrained=False, **kwargs): return wrapper -@IMAGE_CLASSIFIER_BACKBONES(name="simclr-imagenet", namespace="vision", package="bolts") -def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **_): - simclr: LightningModule = SimCLR.load_from_checkpoint(path_or_url, strict=False) - # remove the last two layers & turn it into a Sequential model - backbone = nn.Sequential(*list(simclr.encoder.children())[:-2]) - return backbone, 2048 - - -@IMAGE_CLASSIFIER_BACKBONES(name="swav-imagenet", namespace="vision", package="bolts") -def load_swav_imagenet( - path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar", - **_, -) -> Tuple[nn.Module, int]: - swav: LightningModule = SwAV.load_from_checkpoint(path_or_url, strict=True) - # remove the last two layers & turn it into a Sequential model - backbone = nn.Sequential(*list(swav.model.children())[:-2]) - return backbone, 2048 - - if _TORCHVISION_AVAILABLE: + HTTPS_VISSL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/" + RESNET50_WEIGHTS_PATHS = { + "supervised": None, + "simclr": HTTPS_VISSL + "simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/" + "model_final_checkpoint_phase799.torch", + "swav": HTTPS_VISSL + "swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20.a0a6b676/" + "model_final_checkpoint_phase799.torch", + "barlow-twins": HTTPS_VISSL + "barlow_twins/barlow_twins_32gpus_4node_imagenet1k_1000ep_resnet50.torch", + } + def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) backbone = model.features @@ -109,10 +88,40 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Modu type=_type ) - def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: - model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) + def _fn_resnet(model_name: str, + pretrained: Union[bool, str] = True, + weights_paths: dict = {"supervised": None}) -> Tuple[nn.Module, int]: + # load according to pretrained if a bool is specified, else set to False + pretrained_flag = (pretrained and isinstance(pretrained, bool)) or (pretrained == "supervised") + + model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained_flag) backbone = nn.Sequential(*list(model.children())[:-2]) num_features = model.fc.in_features + + model_weights = None + if not pretrained_flag and isinstance(pretrained, str): + if pretrained in weights_paths: + device = next(model.parameters()).get_device() + model_weights = load_state_dict_from_url( + weights_paths[pretrained], + map_location=torch.device('cpu') if device is -1 else torch.device(device) + ) + + # add logic here for loading resnet weights from other libraries + if "classy_state_dict" in model_weights.keys(): + model_weights = model_weights["classy_state_dict"]["base_model"]["model"]["trunk"] + model_weights = { + key.replace("_feature_blocks.", "") if "_feature_blocks." in key else key: val + for (key, val) in model_weights.items() + } + else: + raise KeyError('Unrecognized state dict. Logic for loading the current state dict missing.') + else: + raise KeyError( + "Requested weights for {0} not available," + " choose from one of {1}".format(model_name, list(weights_paths.keys())) + ) + return backbone, num_features def _fn_resnet_fpn( @@ -125,14 +134,27 @@ def _fn_resnet_fpn( return backbone, 256 for model_name in RESNET_MODELS: - IMAGE_CLASSIFIER_BACKBONES( - fn=catch_url_error(partial(_fn_resnet, model_name)), + clf_kwargs = dict( + fn=catch_url_error(partial(_fn_resnet, model_name=model_name)), name=model_name, namespace="vision", package="torchvision", - type="resnet" + type="resnet", + weights_paths={"supervised": None} ) + if model_name == 'resnet50': + clf_kwargs.update( + dict( + fn=catch_url_error( + partial(_fn_resnet, model_name=model_name, weights_paths=RESNET50_WEIGHTS_PATHS) + ), + package="multiple", + weights_paths=RESNET50_WEIGHTS_PATHS + ) + ) + IMAGE_CLASSIFIER_BACKBONES(**clf_kwargs) + OBJ_DETECTION_BACKBONES( fn=catch_url_error(partial(_fn_resnet_fpn, model_name)), name=model_name, diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index 71f6d189ad..ab58b7e66f 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -51,7 +51,8 @@ def fn_resnet(pretrained: bool = True): Args: num_classes: Number of classes to classify. backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"resnet18"``. - pretrained: Use a pretrained backbone, defaults to ``True``. + pretrained: A bool or string to specify the pretrained weights of the backbone, defaults to ``True`` + which loads the default supervised pretrained weights. loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` @@ -73,7 +74,7 @@ def __init__( backbone: Union[str, Tuple[nn.Module, int]] = "resnet18", backbone_kwargs: Optional[Dict] = None, head: Optional[Union[FunctionType, nn.Module]] = None, - pretrained: bool = True, + pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -134,6 +135,16 @@ def forward(self, x) -> torch.Tensor: x = x.mean(-1).mean(-1) return self.head(x) + @classmethod + def available_pretrained_weights(cls, backbone: str): + result = cls.backbones.get(backbone, with_metadata=True) + pretrained_weights = None + + if "weights_paths" in result["metadata"]: + pretrained_weights = list(result["metadata"]["weights_paths"].keys()) + + return pretrained_weights + def _ci_benchmark_fn(self, history: List[Dict[str, Any]]): """ This function is used only for debugging usage with CI diff --git a/tests/image/test_backbones.py b/tests/image/test_backbones.py index 6036927555..bb8ea8791b 100644 --- a/tests/image/test_backbones.py +++ b/tests/image/test_backbones.py @@ -16,15 +16,13 @@ import pytest from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE -from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE +from flash.core.utilities.imports import _TIMM_AVAILABLE from flash.image.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES @pytest.mark.parametrize(["backbone", "expected_num_features"], [ pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _TIMM_AVAILABLE, reason="No timm")), - pytest.param("simclr-imagenet", 2048, marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")), - pytest.param("swav-imagenet", 2048, marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")), pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), ]) def test_image_classifier_backbones_registry(backbone, expected_num_features): @@ -34,6 +32,21 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features): assert num_features == expected_num_features +@pytest.mark.parametrize(["backbone", "pretrained", "expected_num_features"], [ + pytest.param( + "resnet50", "supervised", 2048, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") + ), + pytest.param( + "resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") + ), +]) +def test_pretrained_weights_registry(backbone, pretrained, expected_num_features): + backbone_fn = IMAGE_CLASSIFIER_BACKBONES.get(backbone) + backbone_model, num_features = backbone_fn(pretrained=pretrained) + assert backbone_model + assert num_features == expected_num_features + + def test_pretrained_backbones_catch_url_error(): def raise_error_if_pretrained(pretrained=False):