Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Pretrained flag and resnet50 pretrained weights #560

Merged
merged 23 commits into from
Jul 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
af97ac0
restructured pretrained weights flag for ImageClassifier
ananyahjha93 Jul 9, 2021
c0ee838
changelog
ananyahjha93 Jul 9, 2021
98ed5ca
changelog
ananyahjha93 Jul 9, 2021
8565a49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2021
9f36733
updated PR
ananyahjha93 Jul 12, 2021
4dda19f
rebase
ananyahjha93 Jul 12, 2021
8f19544
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
681c99b
formatting
ananyahjha93 Jul 12, 2021
d2eb357
Merge branch 'simclr-strategy' of https://github.com/PyTorchLightning…
ananyahjha93 Jul 12, 2021
15754c1
Format code with autopep8
deepsource-autofix[bot] Jul 12, 2021
40b1ebe
formatting
ananyahjha93 Jul 12, 2021
083cb22
formatting
ananyahjha93 Jul 12, 2021
833775c
formatting
ananyahjha93 Jul 12, 2021
7f80ee4
removed temp code from example
ananyahjha93 Jul 12, 2021
55a34ab
removed temp code from example
ananyahjha93 Jul 12, 2021
f89aa4c
removed temp code from example
ananyahjha93 Jul 12, 2021
e4b5755
Merge branch 'master' into simclr-strategy
ananyahjha93 Jul 12, 2021
aee44b9
tests
ananyahjha93 Jul 12, 2021
fd0d361
Merge branch 'simclr-strategy' of https://github.com/PyTorchLightning…
ananyahjha93 Jul 12, 2021
78c4f62
Format code with autopep8
deepsource-autofix[bot] Jul 12, 2021
9b9e5c5
tests
ananyahjha93 Jul 12, 2021
457989f
tests
ananyahjha93 Jul 12, 2021
d15d66b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method ([#552](https://github.com/PyTorchLightning/lightning-flash/pull/552))

- Added support for `from_csv` and `from_data_frame` to `ImageClassificationData` ([#556](https://github.com/PyTorchLightning/lightning-flash/pull/556))

- Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Borda marked this conversation as resolved.
Show resolved Hide resolved

- Removed bolts pretrained weights for SSL from ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

### Deprecated

Expand Down
4 changes: 2 additions & 2 deletions docs/source/template/backbones.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ Here's another example with a slightly more complex model:
:language: python
:pyobject: load_mlp_128_256

Here's a more advanced example, which adds ``SimCLR`` to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/backbones.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/backbones.py>`_:
Here's another example, which adds a ``DINO`` pretrained model from PyTorch Hub to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/backbones.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/backbones.py>`_:

.. literalinclude:: ../../../flash/image/backbones.py
:language: python
:pyobject: load_simclr_imagenet
:pyobject: dino_vitb16

------

Expand Down
100 changes: 61 additions & 39 deletions flash/image/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import urllib.error
import warnings
from functools import partial
from typing import Tuple
from typing import Tuple, Union

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from torch import nn
from torch.hub import load_state_dict_from_url

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE

if _TIMM_AVAILABLE:
import timm
Expand All @@ -33,21 +31,11 @@
import torchvision
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.self_supervised import SimCLR, SwAV
else:
from pl_bolts.models.self_supervised import SimCLR, SwAV

ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"

MOBILENET_MODELS = ["mobilenet_v2"]
VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"]
RESNET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"]
DENSENET_MODELS = ["densenet121", "densenet169", "densenet161"]
TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS
BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"]

IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones")
OBJ_DETECTION_BACKBONES = FlashRegistry("backbones")
Expand All @@ -71,27 +59,18 @@ def wrapper(*args, pretrained=False, **kwargs):
return wrapper


@IMAGE_CLASSIFIER_BACKBONES(name="simclr-imagenet", namespace="vision", package="bolts")
def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **_):
simclr: LightningModule = SimCLR.load_from_checkpoint(path_or_url, strict=False)
# remove the last two layers & turn it into a Sequential model
backbone = nn.Sequential(*list(simclr.encoder.children())[:-2])
return backbone, 2048


@IMAGE_CLASSIFIER_BACKBONES(name="swav-imagenet", namespace="vision", package="bolts")
def load_swav_imagenet(
path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar",
**_,
) -> Tuple[nn.Module, int]:
swav: LightningModule = SwAV.load_from_checkpoint(path_or_url, strict=True)
# remove the last two layers & turn it into a Sequential model
backbone = nn.Sequential(*list(swav.model.children())[:-2])
return backbone, 2048


if _TORCHVISION_AVAILABLE:

# Root URL under which the self-supervised ResNet-50 checkpoints below are hosted.
HTTPS_VISSL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/"

# Maps each value accepted by ``pretrained`` for resnet50 to the URL of its checkpoint.
# "supervised" has no URL: that option is served by torchvision's own ``pretrained`` flag.
RESNET50_WEIGHTS_PATHS = {
    "supervised": None,
    "simclr": (
        f"{HTTPS_VISSL}simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/"
        "model_final_checkpoint_phase799.torch"
    ),
    "swav": (
        f"{HTTPS_VISSL}swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20.a0a6b676/"
        "model_final_checkpoint_phase799.torch"
    ),
    "barlow-twins": (
        f"{HTTPS_VISSL}barlow_twins/barlow_twins_32gpus_4node_imagenet1k_1000ep_resnet50.torch"
    ),
}

def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = model.features
Expand All @@ -109,10 +88,40 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Modu
type=_type
)

def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
def _resnet_state_dict_from_checkpoint(checkpoint: dict) -> dict:
    """Turn a downloaded checkpoint into a state dict keyed like a torchvision ResNet.

    Only checkpoints that carry a top-level ``"classy_state_dict"`` entry are
    recognised. Their trunk weights are extracted and the ``"_feature_blocks."``
    fragment is removed from every key.

    Args:
        checkpoint: The object returned by ``load_state_dict_from_url``.

    Returns:
        A mapping from parameter name to value for the ResNet trunk.

    Raises:
        KeyError: If the checkpoint layout is not recognised.
    """
    # add logic here for loading resnet weights from other libraries
    if "classy_state_dict" not in checkpoint:
        raise KeyError('Unrecognized state dict. Logic for loading the current state dict missing.')

    trunk = checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]
    return {key.replace("_feature_blocks.", ""): val for key, val in trunk.items()}


def _fn_resnet(
    model_name: str,
    pretrained: Union[bool, str] = True,
    weights_paths: Union[dict, None] = None,
) -> Tuple[nn.Module, int]:
    """Build a torchvision ResNet backbone, optionally initialised with pretrained weights.

    Args:
        model_name: Name of the constructor to look up on ``torchvision.models``.
        pretrained: ``True`` or ``"supervised"`` loads torchvision's own pretrained
            weights. Any other string selects an entry of ``weights_paths`` to
            download. ``False`` leaves the model randomly initialised.
        weights_paths: Mapping from weights name to checkpoint URL. Defaults to
            ``{"supervised": None}``.

    Returns:
        A ``(backbone, num_features)`` tuple: the model without its last two
        children, and the input width of the model's ``fc`` layer.

    Raises:
        KeyError: If ``pretrained`` is a string that is not in ``weights_paths``,
            or if the downloaded checkpoint has an unrecognised layout.
    """
    if weights_paths is None:
        # Build the default per call rather than sharing one mutable dict between calls.
        weights_paths = {"supervised": None}

    # load according to pretrained if a bool is specified, else set to False
    pretrained_flag = pretrained is True or pretrained == "supervised"
    load_from_url = not pretrained_flag and isinstance(pretrained, str)

    # Reject unknown weight names before paying for model construction.
    if load_from_url and pretrained not in weights_paths:
        raise KeyError(
            "Requested weights for {0} not available,"
            " choose from one of {1}".format(model_name, list(weights_paths.keys()))
        )

    model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained_flag)
    backbone = nn.Sequential(*list(model.children())[:-2])
    num_features = model.fc.in_features

    if load_from_url:
        checkpoint = load_state_dict_from_url(
            weights_paths[pretrained],
            # Place the downloaded tensors on the device the model already lives on.
            map_location=next(model.parameters()).device,
        )
        model_weights = _resnet_state_dict_from_checkpoint(checkpoint)
        # ``backbone`` shares its modules with ``model``, so loading into ``model``
        # initialises the returned backbone too. ``strict=False`` tolerates key
        # mismatches -- presumably the trunk checkpoint has no ``fc`` weights.
        model.load_state_dict(model_weights, strict=False)

    return backbone, num_features

def _fn_resnet_fpn(
Expand All @@ -125,14 +134,27 @@ def _fn_resnet_fpn(
return backbone, 256

for model_name in RESNET_MODELS:
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_resnet, model_name)),
clf_kwargs = dict(
fn=catch_url_error(partial(_fn_resnet, model_name=model_name)),
name=model_name,
namespace="vision",
package="torchvision",
type="resnet"
type="resnet",
weights_paths={"supervised": None}
)

if model_name == 'resnet50':
clf_kwargs.update(
dict(
fn=catch_url_error(
partial(_fn_resnet, model_name=model_name, weights_paths=RESNET50_WEIGHTS_PATHS)
),
package="multiple",
weights_paths=RESNET50_WEIGHTS_PATHS
)
)
IMAGE_CLASSIFIER_BACKBONES(**clf_kwargs)

OBJ_DETECTION_BACKBONES(
fn=catch_url_error(partial(_fn_resnet_fpn, model_name)),
name=model_name,
Expand Down
15 changes: 13 additions & 2 deletions flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def fn_resnet(pretrained: bool = True):
Args:
num_classes: Number of classes to classify.
backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"resnet18"``.
pretrained: Use a pretrained backbone, defaults to ``True``.
pretrained: A bool or string to specify the pretrained weights of the backbone, defaults to ``True``
which loads the default supervised pretrained weights.
loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`.
metrics: Metrics to compute for training and evaluation. Can either be a metric from the `torchmetrics`
Expand All @@ -73,7 +74,7 @@ def __init__(
backbone: Union[str, Tuple[nn.Module, int]] = "resnet18",
backbone_kwargs: Optional[Dict] = None,
head: Optional[Union[FunctionType, nn.Module]] = None,
pretrained: bool = True,
pretrained: Union[bool, str] = True,
loss_fn: Optional[Callable] = None,
optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -134,6 +135,16 @@ def forward(self, x) -> torch.Tensor:
x = x.mean(-1).mean(-1)
return self.head(x)

@classmethod
def available_pretrained_weights(cls, backbone: str):
    """List the pretrained weights registered for ``backbone``.

    Args:
        backbone: Name of a backbone in ``cls.backbones``.

    Returns:
        The keys of the backbone's ``"weights_paths"`` metadata entry as a list,
        or ``None`` when the metadata has no such entry.
    """
    metadata = cls.backbones.get(backbone, with_metadata=True)["metadata"]

    if "weights_paths" not in metadata:
        return None

    return list(metadata["weights_paths"].keys())

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
"""
This function is used only for debugging usage with CI
Expand Down
19 changes: 16 additions & 3 deletions tests/image/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@
import pytest
from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE

from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE
from flash.core.utilities.imports import _TIMM_AVAILABLE
from flash.image.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES


@pytest.mark.parametrize(["backbone", "expected_num_features"], [
pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _TIMM_AVAILABLE, reason="No timm")),
pytest.param("simclr-imagenet", 2048, marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
pytest.param("swav-imagenet", 2048, marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
])
def test_image_classifier_backbones_registry(backbone, expected_num_features):
Expand All @@ -34,6 +32,21 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features):
assert num_features == expected_num_features


@pytest.mark.parametrize(
    ["backbone", "pretrained", "expected_num_features"],
    [
        pytest.param(
            "resnet50",
            weights_name,
            2048,
            marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision"),
        )
        for weights_name in ("supervised", "simclr")
    ],
)
def test_pretrained_weights_registry(backbone, pretrained, expected_num_features):
    """Check that a registered backbone can be built with a named set of pretrained weights."""
    build_backbone = IMAGE_CLASSIFIER_BACKBONES.get(backbone)
    model, feature_count = build_backbone(pretrained=pretrained)

    assert model
    assert feature_count == expected_num_features


def test_pretrained_backbones_catch_url_error():

def raise_error_if_pretrained(pretrained=False):
Expand Down