Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Pretrained flag and resnet50 pretrained weights (#560)
Browse files Browse the repository at this point in the history
* restructured pretrained weights flag for ImageClassifier

* changelog

* changelog

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* updated PR

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* formatting

* Format code with autopep8

* formatting

* formatting

* removed temp code from example

* removed temp code from example

* removed temp code from example

* tests

* Format code with autopep8

* tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 12, 2021
1 parent 48bdfd8 commit bf1526f
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 46 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method ([#552](https://github.com/PyTorchLightning/lightning-flash/pull/552))

- Added support for `from_csv` and `from_data_frame` to `ImageClassificationData` ([#556](https://github.com/PyTorchLightning/lightning-flash/pull/556))

- Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

- Removed bolts pretrained weights for SSL from ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

### Deprecated

Expand Down
4 changes: 2 additions & 2 deletions docs/source/template/backbones.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ Here's another example with a slightly more complex model:
:language: python
:pyobject: load_mlp_128_256

Here's a more advanced example, which adds ``SimCLR`` to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/backbones.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/backbones.py>`_:
Here's another example, which adds the ``DINO`` pretrained model from PyTorch Hub to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/backbones.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/backbones.py>`_:

.. literalinclude:: ../../../flash/image/backbones.py
:language: python
:pyobject: load_simclr_imagenet
:pyobject: dino_vitb16

------

Expand Down
100 changes: 61 additions & 39 deletions flash/image/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import urllib.error
import warnings
from functools import partial
from typing import Tuple
from typing import Tuple, Union

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from torch import nn
from torch.hub import load_state_dict_from_url

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE

if _TIMM_AVAILABLE:
import timm
Expand All @@ -33,21 +31,11 @@
import torchvision
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.self_supervised import SimCLR, SwAV
else:
from pl_bolts.models.self_supervised import SimCLR, SwAV

ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"

# Backbone names grouped by model family.
MOBILENET_MODELS = ["mobilenet_v2"]
VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"]
RESNET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"]
DENSENET_MODELS = ["densenet121", "densenet169", "densenet161"]
# All of the family lists above, concatenated into a single list.
TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS
# Names under which the pl_bolts self-supervised backbones are registered.
BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"]

# Registries that map a backbone name to the function that builds it.
IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones")
OBJ_DETECTION_BACKBONES = FlashRegistry("backbones")
Expand All @@ -71,27 +59,18 @@ def wrapper(*args, pretrained=False, **kwargs):
return wrapper


@IMAGE_CLASSIFIER_BACKBONES(name="simclr-imagenet", namespace="vision", package="bolts")
def load_simclr_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt",
    **_,
):
    """Build an image backbone from a SimCLR checkpoint.

    Args:
        path_or_url: Path or URL of the checkpoint handed to ``SimCLR.load_from_checkpoint``.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        A ``(backbone, num_features)`` tuple: the encoder's children without the last
        two, wrapped in ``nn.Sequential``, and the fixed feature size ``2048``.
    """
    checkpoint_model: LightningModule = SimCLR.load_from_checkpoint(path_or_url, strict=False)
    encoder_layers = list(checkpoint_model.encoder.children())
    # Drop the final two layers and keep the remainder as a Sequential feature extractor.
    return nn.Sequential(*encoder_layers[:-2]), 2048


@IMAGE_CLASSIFIER_BACKBONES(name="swav-imagenet", namespace="vision", package="bolts")
def load_swav_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar",
    **_,
) -> Tuple[nn.Module, int]:
    """Build an image backbone from a SwAV checkpoint.

    Args:
        path_or_url: Path or URL of the checkpoint handed to ``SwAV.load_from_checkpoint``.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        A ``(backbone, num_features)`` tuple: the children of the checkpoint's ``model``
        attribute without the last two, wrapped in ``nn.Sequential``, and the fixed
        feature size ``2048``.
    """
    checkpoint_model: LightningModule = SwAV.load_from_checkpoint(path_or_url, strict=True)
    trunk_layers = list(checkpoint_model.model.children())
    # Drop the final two layers and keep the remainder as a Sequential feature extractor.
    return nn.Sequential(*trunk_layers[:-2]), 2048


if _TORCHVISION_AVAILABLE:

# Base URL shared by every VISSL model-zoo checkpoint listed below.
HTTPS_VISSL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/"
# Maps each value accepted for ``pretrained`` on resnet50 to the checkpoint URL to download.
# "supervised" has no URL: _fn_resnet turns it into torchvision's own ``pretrained`` flag.
# Adjacent string literals are joined before ``+`` is applied, so each two-line entry is one URL.
RESNET50_WEIGHTS_PATHS = {
    "supervised": None,
    "simclr": HTTPS_VISSL + "simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/"
    "model_final_checkpoint_phase799.torch",
    "swav": HTTPS_VISSL + "swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20.a0a6b676/"
    "model_final_checkpoint_phase799.torch",
    "barlow-twins": HTTPS_VISSL + "barlow_twins/barlow_twins_32gpus_4node_imagenet1k_1000ep_resnet50.torch",
}

def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = model.features
Expand All @@ -109,10 +88,40 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Modu
type=_type
)

def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
def _fn_resnet(
    model_name: str,
    pretrained: Union[bool, str] = True,
    weights_paths: Union[dict, None] = None,
) -> Tuple[nn.Module, int]:
    """Build a torchvision ResNet feature extractor, optionally with pretrained weights.

    Args:
        model_name: Name of a model constructor in ``torchvision.models``.
        pretrained: ``True`` or ``"supervised"`` asks torchvision for its own pretrained
            weights. Any other string is looked up in ``weights_paths`` and the
            checkpoint found at that URL is loaded. Any other value leaves the model
            with torchvision's default initialisation.
        weights_paths: Mapping from weights name to checkpoint URL. ``None`` (the
            default) is treated as ``{"supervised": None}``.

    Returns:
        A ``(backbone, num_features)`` tuple: the model's children without the last two,
        wrapped in ``nn.Sequential``, and ``model.fc.in_features``.

    Raises:
        KeyError: If ``pretrained`` is a string that is neither ``"supervised"`` nor a key
            of ``weights_paths``, or if the downloaded checkpoint has no
            ``"classy_state_dict"`` entry.
    """
    # A None default avoids sharing one mutable dict between calls.
    if weights_paths is None:
        weights_paths = {"supervised": None}

    # Only a real bool ``True`` or the string "supervised" selects torchvision's weights.
    use_torchvision_weights = (isinstance(pretrained, bool) and pretrained) or pretrained == "supervised"

    model: nn.Module = getattr(torchvision.models, model_name, None)(use_torchvision_weights)
    backbone = nn.Sequential(*list(model.children())[:-2])
    num_features = model.fc.in_features

    # Nothing further to load unless a named set of external weights was requested.
    if use_torchvision_weights or not isinstance(pretrained, str):
        return backbone, num_features

    if pretrained not in weights_paths:
        raise KeyError(
            f"Requested weights for {model_name} not available,"
            f" choose from one of {list(weights_paths.keys())}"
        )

    # Map the checkpoint straight onto the device the model's parameters live on.
    checkpoint = load_state_dict_from_url(
        weights_paths[pretrained],
        map_location=next(model.parameters()).device,
    )

    # add logic here for loading resnet weights from other libraries
    if "classy_state_dict" not in checkpoint:
        raise KeyError('Unrecognized state dict. Logic for loading the current state dict missing.')

    trunk = checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]
    # Strip the "_feature_blocks." fragment so the keys line up with the torchvision parameter names.
    state_dict = {key.replace("_feature_blocks.", ""): val for key, val in trunk.items()}

    # ``backbone`` wraps the very same child modules as ``model``, so loading into ``model``
    # updates the returned backbone. ``strict=False`` tolerates keys that do not match.
    model.load_state_dict(state_dict, strict=False)

    return backbone, num_features

def _fn_resnet_fpn(
Expand All @@ -125,14 +134,27 @@ def _fn_resnet_fpn(
return backbone, 256

for model_name in RESNET_MODELS:
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_resnet, model_name)),
clf_kwargs = dict(
fn=catch_url_error(partial(_fn_resnet, model_name=model_name)),
name=model_name,
namespace="vision",
package="torchvision",
type="resnet"
type="resnet",
weights_paths={"supervised": None}
)

if model_name == 'resnet50':
clf_kwargs.update(
dict(
fn=catch_url_error(
partial(_fn_resnet, model_name=model_name, weights_paths=RESNET50_WEIGHTS_PATHS)
),
package="multiple",
weights_paths=RESNET50_WEIGHTS_PATHS
)
)
IMAGE_CLASSIFIER_BACKBONES(**clf_kwargs)

OBJ_DETECTION_BACKBONES(
fn=catch_url_error(partial(_fn_resnet_fpn, model_name)),
name=model_name,
Expand Down
15 changes: 13 additions & 2 deletions flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def fn_resnet(pretrained: bool = True):
Args:
num_classes: Number of classes to classify.
backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"resnet18"``.
pretrained: Use a pretrained backbone, defaults to ``True``.
pretrained: A bool or string to specify the pretrained weights of the backbone, defaults to ``True``
which loads the default supervised pretrained weights.
loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`.
metrics: Metrics to compute for training and evaluation. Can either be a metric from the `torchmetrics`
Expand All @@ -73,7 +74,7 @@ def __init__(
backbone: Union[str, Tuple[nn.Module, int]] = "resnet18",
backbone_kwargs: Optional[Dict] = None,
head: Optional[Union[FunctionType, nn.Module]] = None,
pretrained: bool = True,
pretrained: Union[bool, str] = True,
loss_fn: Optional[Callable] = None,
optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -134,6 +135,16 @@ def forward(self, x) -> torch.Tensor:
x = x.mean(-1).mean(-1)
return self.head(x)

@classmethod
def available_pretrained_weights(cls, backbone: str):
    """List the names of the pretrained weights registered for ``backbone``.

    Args:
        backbone: Name of the backbone to look up in ``cls.backbones``.

    Returns:
        The keys of the backbone's ``"weights_paths"`` metadata entry as a list, or
        ``None`` when its metadata has no such entry.
    """
    metadata = cls.backbones.get(backbone, with_metadata=True)["metadata"]
    if "weights_paths" not in metadata:
        return None
    return list(metadata["weights_paths"].keys())

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
"""
This function is used only for debugging usage with CI
Expand Down
19 changes: 16 additions & 3 deletions tests/image/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@
import pytest
from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE

from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE
from flash.core.utilities.imports import _TIMM_AVAILABLE
from flash.image.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES


@pytest.mark.parametrize(["backbone", "expected_num_features"], [
pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _TIMM_AVAILABLE, reason="No timm")),
pytest.param("simclr-imagenet", 2048, marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
pytest.param("swav-imagenet", 2048, marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
])
def test_image_classifier_backbones_registry(backbone, expected_num_features):
Expand All @@ -34,6 +32,21 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features):
assert num_features == expected_num_features


_requires_torchvision = pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")


@pytest.mark.parametrize(
    ("backbone", "pretrained", "expected_num_features"),
    [
        pytest.param("resnet50", "supervised", 2048, marks=_requires_torchvision),
        pytest.param("resnet50", "simclr", 2048, marks=_requires_torchvision),
    ],
)
def test_pretrained_weights_registry(backbone, pretrained, expected_num_features):
    """Check that a registered backbone builds when ``pretrained`` names a set of weights."""
    build_backbone = IMAGE_CLASSIFIER_BACKBONES.get(backbone)
    model, feature_count = build_backbone(pretrained=pretrained)
    assert model
    assert feature_count == expected_num_features


def test_pretrained_backbones_catch_url_error():

def raise_error_if_pretrained(pretrained=False):
Expand Down

0 comments on commit bf1526f

Please sign in to comment.