From d190c7675ccc2198fe954828a50a29c9a96006f2 Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Tue, 21 Sep 2021 15:40:54 -0400 Subject: [PATCH] VISSL initial integration (#682) Co-authored-by: Ethan Harris --- .github/workflows/ci-testing.yml | 3 +- CHANGELOG.md | 2 + flash/core/registry.py | 6 +- flash/core/utilities/providers.py | 1 + flash/image/embedding/backbones/__init__.py | 5 + .../embedding/backbones/vissl_backbones.py | 117 ++++++++++ flash/image/embedding/heads/__init__.py | 5 + flash/image/embedding/heads/vissl_heads.py | 147 +++++++++++++ flash/image/embedding/losses/__init__.py | 5 + flash/image/embedding/losses/vissl_losses.py | 171 +++++++++++++++ flash/image/embedding/model.py | 133 +++++------- flash/image/embedding/strategies/__init__.py | 5 + .../embedding/strategies/vissl_strategies.py | 70 ++++++ flash/image/embedding/transforms/__init__.py | 5 + .../embedding/transforms/vissl_transforms.py | 73 +++++++ .../embedding}/vissl/__init__.py | 0 flash/image/embedding/vissl/adapter.py | 200 ++++++++++++++++++ flash/image/embedding/vissl/hooks.py | 94 ++++++++ .../embedding}/vissl/transforms/__init__.py | 5 +- .../embedding}/vissl/transforms/multicrop.py | 0 .../embedding}/vissl/transforms/utilities.py | 2 +- flash_examples/image_embedder.py | 46 +++- .../integrations/vissl/test_transforms.py | 39 +--- tests/image/embedding/test_model.py | 61 ++++-- tests/image/embedding/utils.py | 44 ++++ 25 files changed, 1098 insertions(+), 141 deletions(-) create mode 100644 flash/image/embedding/backbones/__init__.py create mode 100644 flash/image/embedding/backbones/vissl_backbones.py create mode 100644 flash/image/embedding/heads/__init__.py create mode 100644 flash/image/embedding/heads/vissl_heads.py create mode 100644 flash/image/embedding/losses/__init__.py create mode 100644 flash/image/embedding/losses/vissl_losses.py create mode 100644 flash/image/embedding/strategies/__init__.py create mode 100644 flash/image/embedding/strategies/vissl_strategies.py create mode 100644 flash/image/embedding/transforms/__init__.py create mode 100644 flash/image/embedding/transforms/vissl_transforms.py rename flash/{core/integrations => image/embedding}/vissl/__init__.py (100%) create mode 100644 flash/image/embedding/vissl/adapter.py create mode 100644 flash/image/embedding/vissl/hooks.py rename flash/{core/integrations => image/embedding}/vissl/transforms/__init__.py (55%) rename flash/{core/integrations => image/embedding}/vissl/transforms/multicrop.py (100%) rename flash/{core/integrations => image/embedding}/vissl/transforms/utilities.py (97%) create mode 100644 tests/image/embedding/utils.py diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 7b342f370a..606dd9b7d2 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -136,7 +136,8 @@ jobs: - name: Install vissl if: matrix.topic[1] == 'image_extras' run: | - pip install git+https://github.com/facebookresearch/vissl.git@master + pip install git+https://github.com/facebookresearch/ClassyVision.git + pip install git+https://github.com/facebookresearch/vissl.git - name: Install graph test dependencies if: matrix.topic[0] == 'graph' diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f5382e0ee..ad47448567 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added support `learn2learn` training_strategy for `ImageClassifier` ([#737](https://github.com/PyTorchLightning/lightning-flash/pull/737)) +- Added `vissl` training_strategies for `ImageEmbedder` ([#682](https://github.com/PyTorchLightning/lightning-flash/pull/682)) + ### Changed - Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759)) diff --git a/flash/core/registry.py b/flash/core/registry.py index 641da4e562..a454948e04 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -111,7 +111,11 @@ def _register_function( if not callable(fn): raise MisconfigurationException(f"You can only register a callable, found: {fn}") - name = name or fn.__name__ + if name is None: + if hasattr(fn, "func"): + name = fn.func.__name__ + else: + name = fn.__name__ if self._verbose: rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}") diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index b4a76516c8..a5bb749246 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -45,3 +45,4 @@ def __str__(self): _FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq") _OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML") _PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo") +_VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl") diff --git a/flash/image/embedding/backbones/__init__.py b/flash/image/embedding/backbones/__init__.py new file mode 100644 index 0000000000..7781040e63 --- /dev/null +++ b/flash/image/embedding/backbones/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.embedding.backbones.vissl_backbones import register_vissl_backbones # noqa: F401 + +IMAGE_EMBEDDER_BACKBONES = FlashRegistry("embedder_backbones") +register_vissl_backbones(IMAGE_EMBEDDER_BACKBONES) diff --git a/flash/image/embedding/backbones/vissl_backbones.py b/flash/image/embedding/backbones/vissl_backbones.py new file mode 100644 index 0000000000..18bb214efe --- /dev/null +++ b/flash/image/embedding/backbones/vissl_backbones.py @@ -0,0 +1,117 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
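+# Each backbone factory below builds a VISSL trunk from an ``AttrDict`` config
+# and returns ``(trunk, num_features)`` so callers can size projection heads,
+# e.g. ``backbone, num_features = IMAGE_EMBEDDER_BACKBONES.get("resnet")()``.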
+import torch.nn as nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _VISSL_AVAILABLE + +if _VISSL_AVAILABLE: + from vissl.config.attr_dict import AttrDict + from vissl.models.model_helpers import RESNET_NORM_LAYER + from vissl.models.trunks import MODEL_TRUNKS_REGISTRY + + from flash.image.embedding.vissl.adapter import VISSLAdapter +else: + RESNET_NORM_LAYER = object + + +def vision_transformer( + image_size: int = 224, + patch_size: int = 16, + hidden_dim: int = 384, + num_layers: int = 12, + num_heads: int = 6, + mlp_dim: int = 1532, + dropout_rate: float = 0, + attention_dropout_rate: float = 0, + drop_path_rate: float = 0, + qkv_bias: bool = True, + qk_scale: bool = False, + classifier: str = "token", + **kwargs, +) -> nn.Module: + + cfg = VISSLAdapter.get_model_config_template() + cfg.TRUNK = AttrDict( + { + "NAME": "vision_transformer", + "VISION_TRANSFORMERS": AttrDict( + { + "IMAGE_SIZE": image_size, + "PATCH_SIZE": patch_size, + "HIDDEN_DIM": hidden_dim, + "NUM_LAYERS": num_layers, + "NUM_HEADS": num_heads, + "MLP_DIM": mlp_dim, + "DROPOUT_RATE": dropout_rate, + "ATTENTION_DROPOUT_RATE": attention_dropout_rate, + "DROP_PATH_RATE": drop_path_rate, + "QKV_BIAS": qkv_bias, + "QK_SCALE": qk_scale, + "CLASSIFIER": classifier, + } + ), + } + ) + + trunk = MODEL_TRUNKS_REGISTRY["vision_transformer"](cfg, model_name="vision_transformer") + trunk.model_config = cfg + + return trunk, trunk.num_features + + +def resnet( + depth: int = 50, + width_multiplier: int = 1, + norm: RESNET_NORM_LAYER = None, + groupnorm_groups: int = 32, + standardize_convolutions: bool = False, + groups: int = 1, + zero_init_residual: bool = False, + width_per_group: int = 64, + layer4_stride: int = 2, + **kwargs, +) -> nn.Module: + if norm is None: + norm = RESNET_NORM_LAYER.BatchNorm + cfg = VISSLAdapter.get_model_config_template() + cfg.TRUNK = AttrDict( + { + "NAME": "resnet", + "RESNETS": AttrDict( + { + "DEPTH": depth, + "WIDTH_MULTIPLIER": width_multiplier, + "NORM": norm, + "GROUPNORM_GROUPS": groupnorm_groups, + "STANDARDIZE_CONVOLUTIONS": standardize_convolutions, + "GROUPS": groups, + "ZERO_INIT_RESIDUAL": zero_init_residual, + "WIDTH_PER_GROUP": width_per_group, + "LAYER4_STRIDE": layer4_stride, + } + ), + } + ) + + trunk = MODEL_TRUNKS_REGISTRY["resnet"](cfg, model_name="resnet") + trunk.model_config = cfg + + return trunk, 2048 + + +def register_vissl_backbones(register: FlashRegistry): + if _VISSL_AVAILABLE: + for backbone in (vision_transformer, resnet): + register(backbone) diff --git a/flash/image/embedding/heads/__init__.py b/flash/image/embedding/heads/__init__.py new file mode 100644 index 0000000000..0afd7bc39d --- /dev/null +++ b/flash/image/embedding/heads/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.embedding.heads.vissl_heads import register_vissl_heads # noqa: F401 + +IMAGE_EMBEDDER_HEADS = FlashRegistry("embedder_heads") +register_vissl_heads(IMAGE_EMBEDDER_HEADS) diff --git a/flash/image/embedding/heads/vissl_heads.py b/flash/image/embedding/heads/vissl_heads.py new file mode 100644 index 0000000000..0eca1456ec --- /dev/null +++ b/flash/image/embedding/heads/vissl_heads.py @@ -0,0 +1,147 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union + +import torch +import torch.nn as nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _VISSL_AVAILABLE + +if _VISSL_AVAILABLE: + from vissl.config.attr_dict import AttrDict + from vissl.models.heads import MODEL_HEADS_REGISTRY, register_model_head + + from flash.image.embedding.vissl.adapter import VISSLAdapter +else: + AttrDict = object + + +class SimCLRHead(nn.Module): + def __init__( + self, + model_config: AttrDict, + dims: List[int] = [2048, 2048, 128], + use_bn: bool = True, + **kwargs, + ) -> nn.Module: + super().__init__() + + self.model_config = model_config + self.dims = dims + self.use_bn = use_bn + + self.clf = self.create_mlp() + + def create_mlp(self): + layers = [] + last_dim = self.dims[0] + + for dim in self.dims[1:-1]: + layers.append(nn.Linear(last_dim, dim)) + + if self.use_bn: + layers.append( + nn.BatchNorm1d( + dim, + eps=self.model_config.HEAD.BATCHNORM_EPS, + momentum=self.model_config.HEAD.BATCHNORM_MOMENTUM, + ) + ) + + layers.append(nn.ReLU(inplace=True)) + last_dim = dim + + layers.append(nn.Linear(last_dim, self.dims[-1])) + return nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.clf(x) + + +if _VISSL_AVAILABLE: + SimCLRHead = register_model_head("simclr_head")(SimCLRHead) + + +def simclr_head( + dims: List[int] = [2048, 2048, 128], + use_bn: bool = True, + **kwargs, +) -> nn.Module: + cfg = VISSLAdapter.get_model_config_template() + head_kwargs = { + "dims": dims, + "use_bn": use_bn, + } + + cfg.HEAD.PARAMS.append(["simclr_head", head_kwargs]) + + head = MODEL_HEADS_REGISTRY["simclr_head"](cfg, **head_kwargs) + head.model_config = cfg + + return head + + +def swav_head( + dims: List[int] = [2048, 2048, 128], + use_bn: bool = True, + num_clusters: Union[int, List[int]] = [3000], + use_bias: bool = True, + return_embeddings: bool = True, + skip_last_bn: bool = True, + normalize_feats: bool = True, + activation_name: str = "ReLU", + use_weight_norm_prototypes: bool = False, + **kwargs, +) -> nn.Module: + cfg = VISSLAdapter.get_model_config_template() + head_kwargs = { + "dims": dims, + "use_bn": use_bn, + "num_clusters": [num_clusters] if isinstance(num_clusters, int) else num_clusters, + "use_bias": use_bias, + "return_embeddings": return_embeddings, + "skip_last_bn": skip_last_bn, + "normalize_feats": normalize_feats, + "activation_name": activation_name, + "use_weight_norm_prototypes": use_weight_norm_prototypes, + } + + cfg.HEAD.PARAMS.append(["swav_head", head_kwargs]) + + head = MODEL_HEADS_REGISTRY["swav_head"](cfg, **head_kwargs) + head.model_config = cfg + + return head + + +def barlow_twins_head(**kwargs) -> nn.Module: + return simclr_head(dims=[2048, 8192, 8192, 8192], **kwargs) + + +def dino_head(**kwargs) -> nn.Module: + return swav_head( + dims=[384, 2048, 2048, 256], + use_bn=False, + return_embeddings=False, + activation_name="GELU", + num_clusters=[65536], + use_weight_norm_prototypes=True, + **kwargs, + ) + + +def register_vissl_heads(register: FlashRegistry): + for ssl_head in (swav_head, simclr_head, dino_head, 
barlow_twins_head): + register(ssl_head) diff --git a/flash/image/embedding/losses/__init__.py b/flash/image/embedding/losses/__init__.py new file mode 100644 index 0000000000..71c0717e21 --- /dev/null +++ b/flash/image/embedding/losses/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.embedding.losses.vissl_losses import register_vissl_losses # noqa: F401 + +IMAGE_EMBEDDER_LOSS_FUNCTIONS = FlashRegistry("embedder_losses") +register_vissl_losses(IMAGE_EMBEDDER_LOSS_FUNCTIONS) diff --git a/flash/image/embedding/losses/vissl_losses.py b/flash/image/embedding/losses/vissl_losses.py new file mode 100644 index 0000000000..87dcf5260c --- /dev/null +++ b/flash/image/embedding/losses/vissl_losses.py @@ -0,0 +1,171 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _VISSL_AVAILABLE + +if _VISSL_AVAILABLE: + import vissl.losses # noqa: F401 + from classy_vision.losses import ClassyLoss, LOSS_REGISTRY + from vissl.config.attr_dict import AttrDict +else: + AttrDict = object + ClassyLoss = object + + +def get_loss_fn(loss_name: str, cfg: AttrDict): + loss_fn = LOSS_REGISTRY[loss_name](cfg) + loss_fn.__dict__["loss_name"] = loss_name + + return loss_fn + + +def dino_loss( + num_crops: int = 10, + momentum: float = 0.996, + student_temp: float = 0.1, + teacher_temp_min: float = 0.04, + teacher_temp_max: float = 0.07, + teacher_temp_warmup_iters: int = 37530, # convert this to 30 epochs + crops_for_teacher: List[int] = [0, 1], + ema_center: float = 0.9, + normalize_last_layer: bool = False, + output_dim: int = 65536, + **kwargs, +) -> ClassyLoss: + loss_name = "dino_loss" + cfg = AttrDict( + { + "num_crops": num_crops, + "momentum": momentum, + "student_temp": student_temp, + "teacher_temp_min": teacher_temp_min, + "teacher_temp_max": teacher_temp_max, + "teacher_temp_warmup_iters": teacher_temp_warmup_iters, + "crops_for_teacher": crops_for_teacher, + "ema_center": ema_center, + "normalize_last_layer": normalize_last_layer, + "output_dim": output_dim, + } + ) + + return get_loss_fn(loss_name, cfg) + + +def swav_loss( + embedding_dim: int = 128, + temperature: float = 0.1, + use_double_precision: bool = False, + normalize_last_layer: bool = True, + num_iters: int = 3, + epsilon: float = 0.05, + num_crops: int = 8, + crops_for_assign: List[int] = [0, 1], + num_prototypes: Union[int, List[int]] = 3000, + temp_hard_assignment_iters: int = 0, + output_dir: str = ".", + queue_length: int = 0, + start_iter: int = 0, + local_queue_length: int = 0, +) -> ClassyLoss: + loss_name = "swav_loss" + cfg = AttrDict( + { + "embedding_dim": embedding_dim, + "temperature": temperature, + "use_double_precision": use_double_precision, + "normalize_last_layer": normalize_last_layer, + "num_iters": num_iters, + "epsilon": epsilon, + "num_crops": num_crops, + "crops_for_assign": crops_for_assign, + "num_prototypes": 
[num_prototypes] if isinstance(num_prototypes, int) else num_prototypes, + "temp_hard_assignment_iters": temp_hard_assignment_iters, + "output_dir": output_dir, + "queue": AttrDict( + { + "queue_length": queue_length, + "start_iter": start_iter, + "local_queue_length": local_queue_length, + } + ), + } + ) + + return get_loss_fn(loss_name, cfg) + + +def barlow_twins_loss( + lambda_: float = 0.0051, scale_loss: float = 0.024, latent_embedding_dim: int = 8192 +) -> ClassyLoss: + loss_name = "barlow_twins_loss" + cfg = AttrDict( + { + "lambda_": lambda_, + "scale_loss": scale_loss, + "embedding_dim": latent_embedding_dim, + } + ) + + return get_loss_fn(loss_name, cfg) + + +def simclr_loss( + temperature: float = 0.1, + embedding_dim: int = 128, + effective_batch_size: int = 1, # set by setup training hook + world_size: int = 1, # set by setup training hook +) -> ClassyLoss: + loss_name = "simclr_info_nce_loss" + cfg = AttrDict( + { + "temperature": temperature, + "buffer_params": AttrDict( + { + "world_size": world_size, + "embedding_dim": embedding_dim, + "effective_batch_size": effective_batch_size, + } + ), + } + ) + + return get_loss_fn(loss_name, cfg) + + +def moco_loss( + embedding_dim: int = 128, + queue_size: int = 65536, + momentum: float = 0.999, + temperature: int = 0.2, + shuffle_batch: bool = True, +) -> ClassyLoss: + loss_name = "moco_loss" + cfg = AttrDict( + { + "embedding_dim": embedding_dim, + "queue_size": queue_size, + "momentum": momentum, + "temperature": temperature, + "shuffle_batch": shuffle_batch, + } + ) + + return get_loss_fn(loss_name, cfg) + + +def register_vissl_losses(register: FlashRegistry): + for loss_fn in (dino_loss, swav_loss, barlow_twins_loss, simclr_loss, moco_loss): + register(loss_fn) diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index c803757ec5..2966ef2ce0 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -11,29 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
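+# ``ImageEmbedder`` is now an ``AdapterTask``: the backbone and self-supervised
+# training strategy are resolved from registries, and the strategy's adapter
+# (``VISSLAdapter``) owns the training/validation/test steps.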
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Type, Union import torch -from pytorch_lightning.utilities import rank_zero_warn -from torch import nn -from torch.nn import functional as F from torch.optim.lr_scheduler import _LRScheduler -from torchmetrics import Accuracy, Metric -from flash.core.data.data_source import DefaultDataKeys -from flash.core.model import Task +from flash.core.adapter import AdapterTask from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _IMAGE_AVAILABLE -from flash.core.utilities.isinstance import _isinstance -from flash.image.classification.data import ImageClassificationPreprocess +from flash.core.utilities.imports import _VISSL_AVAILABLE -if _IMAGE_AVAILABLE: - from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES +if _VISSL_AVAILABLE: + import classy_vision + import classy_vision.generic.distributed_util + + from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES + from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES + + # patch this to avoid classy vision/vissl based distributed training + classy_vision.generic.distributed_util.get_world_size = lambda: 1 else: - IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones") + IMAGE_EMBEDDER_BACKBONES = FlashRegistry("backbones") + IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies") -class ImageEmbedder(Task): +class ImageEmbedder(AdapterTask): """The ``ImageEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from images. For more details, see :ref:`image_embedder`. @@ -54,87 +55,63 @@ class ImageEmbedder(Task): pooling_fn: Function used to pool image to generate embeddings, defaults to :func:`torch.max`. 
""" - backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES + training_strategies: FlashRegistry = IMAGE_EMBEDDER_STRATEGIES + backbones: FlashRegistry = IMAGE_EMBEDDER_BACKBONES required_extras: str = "image" def __init__( self, - embedding_dim: Optional[int] = None, - backbone: str = "resnet101", + training_strategy: str, + embedding_dim: int = 128, + backbone: str = "resnet", pretrained: bool = True, - loss_fn: Callable = F.cross_entropy, optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = (Accuracy()), learning_rate: float = 1e-3, - pooling_fn: Callable = torch.max, + **kwargs: Any, ): - super().__init__( - model=None, + self.save_hyperparameters() + + backbone, num_features = self.backbones.get(backbone)(**kwargs, pretrained=pretrained) + + # TODO: add linear layer to backbone to get num_feature -> embedding_dim before applying heads + # assert embedding_dim == num_features + + metadata = self.training_strategies.get(training_strategy, with_metadata=True) + loss_fn, head, hooks = metadata["fn"](**kwargs) + + adapter = metadata["metadata"]["adapter"].from_task( + self, loss_fn=loss_fn, + backbone=backbone, + head=head, + hooks=hooks, + ) + + super().__init__( + adapter=adapter, optimizer=optimizer, optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, - metrics=metrics, learning_rate=learning_rate, - preprocess=ImageClassificationPreprocess(), ) - self.save_hyperparameters() - self.backbone_name = backbone - self.embedding_dim = embedding_dim - assert pooling_fn in [torch.mean, torch.max] - self.pooling_fn = pooling_fn - - self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained) - - if embedding_dim is None: - self.head = nn.Identity() - else: - self.head = nn.Sequential( - nn.Flatten(), - nn.Linear(num_features, embedding_dim), - ) - rank_zero_warn("Adding linear layer on top of backbone. 
Remember to finetune first before using!") - - def apply_pool(self, x): - x = self.pooling_fn(x, dim=-1) - if _isinstance(x, Tuple[torch.Tensor, torch.Tensor]): - x = x[0] - x = self.pooling_fn(x, dim=-1) - if _isinstance(x, Tuple[torch.Tensor, torch.Tensor]): - x = x[0] - return x - - def forward(self, x) -> torch.Tensor: - x = self.backbone(x) - - # bolts ssl models return lists - if isinstance(x, tuple): - x = x[-1] - - if x.dim() == 4 and not self.embedding_dim: - x = self.apply_pool(x) - - x = self.head(x) - return x - - def training_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return super().training_step(batch, batch_idx) - - def validation_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return super().validation_step(batch, batch_idx) - - def test_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return super().test_step(batch, batch_idx) - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = batch[DefaultDataKeys.INPUT] - return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + def on_train_start(self) -> None: + self.adapter.on_train_start() + + def on_train_epoch_end(self) -> None: + self.adapter.on_train_epoch_end() + + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + self.adapter.on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) + + @classmethod + def available_training_strategies(cls) -> List[str]: + registry: Optional[FlashRegistry] = getattr(cls, "training_strategies", None) + if registry is None: + return [] + return registry.available_keys() diff --git a/flash/image/embedding/strategies/__init__.py b/flash/image/embedding/strategies/__init__.py new file mode 100644 index 0000000000..8d010d7bb8 --- /dev/null +++ b/flash/image/embedding/strategies/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.embedding.strategies.vissl_strategies import register_vissl_strategies # noqa: F401 + +IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies") +register_vissl_strategies(IMAGE_EMBEDDER_STRATEGIES) diff --git a/flash/image/embedding/strategies/vissl_strategies.py b/flash/image/embedding/strategies/vissl_strategies.py new file mode 100644 index 0000000000..2622d7ae5b --- /dev/null +++ b/flash/image/embedding/strategies/vissl_strategies.py @@ -0,0 +1,70 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
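+# Each strategy factory returns ``(loss_fn, head, hooks)``; the hooks keep
+# VISSL loss state (memory queues, prototypes, momentum encoders) in sync
+# during training. ``register_vissl_strategies`` attaches ``VISSLAdapter`` as
+# registry metadata so ``ImageEmbedder`` can build the adapter for the
+# selected strategy.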
+from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _VISSL_AVAILABLE +from flash.core.utilities.providers import _VISSL +from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS +from flash.image.embedding.losses import IMAGE_EMBEDDER_LOSS_FUNCTIONS +from flash.image.embedding.vissl.adapter import VISSLAdapter +from flash.image.embedding.vissl.hooks import SimCLRTrainingSetupHook, TrainingSetupHook + +if _VISSL_AVAILABLE: + from vissl.hooks.dino_hooks import DINOHook + from vissl.hooks.moco_hooks import MoCoHook + from vissl.hooks.swav_hooks import NormalizePrototypesHook, SwAVUpdateQueueScoresHook + + +def dino(head: str = "dino_head", **kwargs): + loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("dino_loss")(**kwargs) + head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) + + return loss_fn, head, [DINOHook(), TrainingSetupHook()] + + +def swav(head: str = "swav_head", **kwargs): + loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("swav_loss")(**kwargs) + head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) + + return loss_fn, head, [SwAVUpdateQueueScoresHook(), NormalizePrototypesHook(), TrainingSetupHook()] + + +def simclr(head: str = "simclr_head", **kwargs): + loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("simclr_loss")(**kwargs) + head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) + + return loss_fn, head, [SimCLRTrainingSetupHook()] + + +def moco(head: str = "simclr_head", **kwargs): + loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("moco_loss")(**kwargs) + head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) + + return ( + loss_fn, + head, + [MoCoHook(loss_fn.loss_config.momentum, loss_fn.loss_config.shuffle_batch), TrainingSetupHook()], + ) + + +def barlow_twins(head: str = "barlow_twins_head", **kwargs): + loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("barlow_twins_loss")(**kwargs) + head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) + + return loss_fn, head, [TrainingSetupHook()] + + +def register_vissl_strategies(register: FlashRegistry): + if _VISSL_AVAILABLE: + for training_strategy in (dino, swav, simclr, moco, barlow_twins): + register(training_strategy, name=training_strategy.__name__, adapter=VISSLAdapter, providers=_VISSL) diff --git a/flash/image/embedding/transforms/__init__.py b/flash/image/embedding/transforms/__init__.py new file mode 100644 index 0000000000..79657f9491 --- /dev/null +++ b/flash/image/embedding/transforms/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.embedding.transforms.vissl_transforms import register_vissl_transforms # noqa: F401 + +IMAGE_EMBEDDER_TRANSFORMS = FlashRegistry("embedder_transforms") +register_vissl_transforms(IMAGE_EMBEDDER_TRANSFORMS) diff --git a/flash/image/embedding/transforms/vissl_transforms.py b/flash/image/embedding/transforms/vissl_transforms.py new file mode 100644 index 0000000000..6635aa6b2a --- /dev/null +++ b/flash/image/embedding/transforms/vissl_transforms.py @@ -0,0 +1,73 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Sequence + +import torch.nn as nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _VISSL_AVAILABLE + +if _VISSL_AVAILABLE: + from classy_vision.dataset.transforms import TRANSFORM_REGISTRY + + +def simclr_transform( + total_num_crops: int = 2, + num_crops: Sequence[int] = [2], + size_crops: Sequence[int] = [224], + crop_scales: Sequence[Sequence[float]] = [[0.4, 1]], + gaussian_blur: bool = True, + jitter_strength: float = 1.0, + normalize: Optional[nn.Module] = None, +) -> nn.Module: + """For simclr, barlow twins and moco.""" + transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"]( + total_num_crops=total_num_crops, + num_crops=num_crops, + size_crops=size_crops, + crop_scales=crop_scales, + gaussian_blur=gaussian_blur, + jitter_strength=jitter_strength, + normalize=normalize, + ) + + return transform + + +def swav_transform( + total_num_crops: int = 8, + num_crops: Sequence[int] = [2, 6], + size_crops: Sequence[int] = [224, 96], + crop_scales: Sequence[Sequence[float]] = [[0.4, 1], [0.05, 0.4]], + gaussian_blur: bool = True, + jitter_strength: float = 1.0, + normalize: Optional[nn.Module] = None, +) -> nn.Module: + """For swav and dino.""" + transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"]( + total_num_crops=total_num_crops, + num_crops=num_crops, + size_crops=size_crops, + crop_scales=crop_scales, + gaussian_blur=gaussian_blur, + jitter_strength=jitter_strength, + normalize=normalize, + ) + + return transform + + +def register_vissl_transforms(register: FlashRegistry): + for transform in (simclr_transform, swav_transform): + register(transform) diff --git a/flash/core/integrations/vissl/__init__.py b/flash/image/embedding/vissl/__init__.py similarity index 100% rename from flash/core/integrations/vissl/__init__.py rename to flash/image/embedding/vissl/__init__.py diff --git a/flash/image/embedding/vissl/adapter.py b/flash/image/embedding/vissl/adapter.py new file mode 100644 index 0000000000..1a1317ce91 --- /dev/null +++ b/flash/image/embedding/vissl/adapter.py @@ -0,0 +1,200 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
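+# ``MockVISSLTask`` mimics just enough of a VISSL/ClassyVision task (device,
+# iteration counters, the last batch) for stock VISSL hooks to run without a
+# VISSL trainer; ``VISSLAdapter`` wraps ``BaseSSLMultiInputOutputModel`` and
+# routes Lightning's step methods through the VISSL loss.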
+from typing import Any, List, Union + +import torch +import torch.nn as nn + +from flash.core.adapter import Adapter +from flash.core.data.data_source import DefaultDataKeys +from flash.core.model import Task +from flash.core.utilities.imports import _VISSL_AVAILABLE +from flash.image.embedding.vissl.hooks import AdaptVISSLHooks + +if _VISSL_AVAILABLE: + from classy_vision.hooks.classy_hook import ClassyHook + from classy_vision.losses import ClassyLoss + from vissl.config.attr_dict import AttrDict + from vissl.models.base_ssl_model import BaseSSLMultiInputOutputModel +else: + ClassyLoss = object + ClassyHook = object + + +class MockVISSLTask: + def __init__(self, vissl_adapter, vissl_loss, task_config, vissl_model) -> None: + self.vissl_adapter = vissl_adapter + self.loss = vissl_loss + self.config = task_config + self.base_model = vissl_model + self.model = self.base_model # set by property in ClassyTask + + # set using trainingsetuphook + self.device = None + + self.iteration = 0 + self.max_iteration = 1 # set by training setup hook + + # set for momentum teacher based hooks + self.last_batch = AttrDict({"sample": AttrDict({"input": None})}) + + +class VISSLAdapter(Adapter, AdaptVISSLHooks): + """The ``VISSLAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with VISSL.""" + + required_extras: str = "image" + + def __init__( + self, + backbone: nn.Module, + head: nn.Module, + loss_fn: ClassyLoss, + hooks: List[ClassyHook], + ) -> None: + + Adapter.__init__(self) + + self.model_config = self.get_model_config_template() + self.optimizer_config = AttrDict({}) + + self.backbone = backbone + self.head = [head] if not isinstance(head, list) else head + self.loss_fn = loss_fn + self.hooks = hooks + + self.model_config.TRUNK = self.backbone.model_config.TRUNK + self.model_config.HEAD = self.head[0].model_config.HEAD + self.task_config = AttrDict( + { + "MODEL": self.model_config, + "OPTIMIZER": self.optimizer_config, + "LOSS": AttrDict( + { + "name": self.loss_fn.loss_name, + self.loss_fn.loss_name: self.loss_fn.loss_config, + } + ), + } + ) + + self.vissl_base_model = BaseSSLMultiInputOutputModel(self.model_config, self.optimizer_config) + # patch backbone and head + self.vissl_base_model.trunk = backbone + self.vissl_base_model.heads = nn.ModuleList(self.head) + + self.vissl_task = MockVISSLTask(self, self.loss_fn, self.task_config, self.vissl_base_model) + + AdaptVISSLHooks.__init__(self, hooks=hooks, task=self.vissl_task) + + @classmethod + def from_task( + cls, + task: Task, + loss_fn: ClassyLoss, + backbone: nn.Module, + head: Union[nn.Module, List[nn.Module]], + hooks: List[ClassyHook], + ) -> Adapter: + result = cls( + backbone=backbone, + head=head, + loss_fn=loss_fn, + hooks=hooks, + ) + + result.__dict__["adapter_task"] = task + + return result + + @staticmethod + def get_model_config_template(): + cfg = AttrDict( + { + "SINGLE_PASS_EVERY_CROP": False, + "INPUT_TYPE": "rgb", + "MULTI_INPUT_HEAD_MAPPING": [], + "TRUNK": AttrDict({}), + "HEAD": AttrDict( + { + "PARAMS": [], + "BATCHNORM_EPS": 1e-5, + "BATCHNORM_MOMENTUM": 0.1, + "PARAMS_MULTIPLIER": 1.0, + } + ), + "FEATURE_EVAL_SETTINGS": AttrDict( + { + "EVAL_MODE_ON": False, + "EXTRACT_TRUNK_FEATURES_ONLY": False, + } + ), + "_MODEL_INIT_SEED": 0, + "ACTIVATION_CHECKPOINTING": AttrDict( + { + "USE_ACTIVATION_CHECKPOINTING": False, + "NUM_ACTIVATION_CHECKPOINTING_SPLITS": 2, + } + ), + } + ) + + return cfg + + def forward(self, batch: torch.Tensor) -> Any: + return self.vissl_base_model.trunk(batch, [])[0] + + def 
ssl_forward(self, batch) -> Any: + model_output = self.vissl_base_model(batch) + + # vissl-specific + if len(model_output) == 1: + model_output = model_output[0] + + return model_output + + def training_step(self, batch: Any, batch_idx: int) -> Any: + out = self.ssl_forward(batch[DefaultDataKeys.INPUT]) + self.task.last_batch["sample"]["input"] = batch[DefaultDataKeys.INPUT] + + # call forward hook from VISSL (momentum updates) + for hook in self.hooks: + hook.on_forward(self.vissl_task) + + loss = self.loss_fn(out, target=None) + self.adapter_task.log_dict({"train_loss": loss.item()}) + + return loss + + def validation_step(self, batch: Any, batch_idx: int) -> None: + out = self.ssl_forward(batch[DefaultDataKeys.INPUT]) + self.task.last_batch["sample"]["input"] = batch[DefaultDataKeys.INPUT] + + loss = self.loss_fn(out, target=None) + self.adapter_task.log_dict({"val_loss": loss}) + + return loss + + def test_step(self, batch: Any, batch_idx: int) -> None: + out = self.ssl_forward(batch[DefaultDataKeys.INPUT]) + self.task.last_batch["sample"]["input"] = batch[DefaultDataKeys.INPUT] + + loss = self.loss_fn(out, target=None) + self.adapter_task.log_dict({"test_loss": loss}) + + return loss + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + input_image = batch[DefaultDataKeys.INPUT] + + return self(input_image) diff --git a/flash/image/embedding/vissl/hooks.py b/flash/image/embedding/vissl/hooks.py new file mode 100644 index 0000000000..efb16cf1e6 --- /dev/null +++ b/flash/image/embedding/vissl/hooks.py @@ -0,0 +1,94 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
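+# ``TrainingSetupHook`` copies device and world-size information from the
+# Lightning trainer onto the mock VISSL task, while ``AdaptVISSLHooks`` replays
+# the ClassyHook lifecycle from Lightning's ``ModelHooks`` callbacks.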
+from typing import Any, List + +import torch +from pytorch_lightning.core.hooks import ModelHooks + +import flash +from flash.core.utilities.imports import _VISSL_AVAILABLE + +if _VISSL_AVAILABLE: + from classy_vision.hooks.classy_hook import ClassyHook +else: + + class ClassyHook: + _noop = object + + +class TrainingSetupHook(ClassyHook): + on_start = ClassyHook._noop + on_phase_start = ClassyHook._noop + on_loss_and_meter = ClassyHook._noop + on_backward = ClassyHook._noop + on_step = ClassyHook._noop + on_phase_end = ClassyHook._noop + on_end = ClassyHook._noop + on_update = ClassyHook._noop + on_forward = ClassyHook._noop + + def __init__(self): + super().__init__() + + @torch.no_grad() + def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") -> None: + lightning_module = task.vissl_adapter.adapter_task + task.device = lightning_module.device + + num_nodes = lightning_module.trainer.num_nodes + accelerators_ids = lightning_module.trainer.accelerator_connector.parallel_device_ids + accelerator_per_node = len(accelerators_ids) if accelerators_ids is not None else 1 + task.world_size = num_nodes * accelerator_per_node + + if lightning_module.trainer.max_epochs is None: + lightning_module.trainer.max_epochs = 1 + + task.max_iteration = lightning_module.trainer.max_epochs * lightning_module.trainer.num_training_batches + + +class SimCLRTrainingSetupHook(TrainingSetupHook): + def __init__(self): + super().__init__() + + @torch.no_grad() + def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") -> None: + super().on_start(task) + + lightning_module = task.vissl_adapter.adapter_task + + task.loss.info_criterion.buffer_params.effective_batch_size = ( + task.world_size * 2 * lightning_module.trainer.datamodule.batch_size + ) + task.loss.info_criterion.buffer_params.world_size = task.world_size + + task.loss.info_criterion.precompute_pos_neg_mask() + + +class AdaptVISSLHooks(ModelHooks): + def __init__(self, hooks: List[ClassyHook], task) -> None: + super().__init__() + + self.hooks = hooks + self.task = task + + def on_train_start(self) -> None: + for hook in self.hooks: + hook.on_start(self.task) + + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + self.task.iteration += 1 + + def on_train_epoch_end(self) -> None: + for hook in self.hooks: + hook.on_update(self.task) diff --git a/flash/core/integrations/vissl/transforms/__init__.py b/flash/image/embedding/vissl/transforms/__init__.py similarity index 55% rename from flash/core/integrations/vissl/transforms/__init__.py rename to flash/image/embedding/vissl/transforms/__init__.py index 804689456e..85179afb7d 100644 --- a/flash/core/integrations/vissl/transforms/__init__.py +++ b/flash/image/embedding/vissl/transforms/__init__.py @@ -1,9 +1,8 @@ from flash.core.utilities.imports import _VISSL_AVAILABLE # noqa: F401 +from flash.image.embedding.vissl.transforms.multicrop import StandardMultiCropSSLTransform # noqa: F401 +from flash.image.embedding.vissl.transforms.utilities import multicrop_collate_fn # noqa: F401 if _VISSL_AVAILABLE: from classy_vision.dataset.transforms import register_transform # noqa: F401 - from flash.core.integrations.vissl.transforms.multicrop import StandardMultiCropSSLTransform # noqa: F401 - from flash.core.integrations.vissl.transforms.utilities import vissl_collate_fn # noqa: F401 - register_transform("multicrop_ssl_transform")(StandardMultiCropSSLTransform) diff --git 
a/flash/core/integrations/vissl/transforms/multicrop.py b/flash/image/embedding/vissl/transforms/multicrop.py similarity index 100% rename from flash/core/integrations/vissl/transforms/multicrop.py rename to flash/image/embedding/vissl/transforms/multicrop.py diff --git a/flash/core/integrations/vissl/transforms/utilities.py b/flash/image/embedding/vissl/transforms/utilities.py similarity index 97% rename from flash/core/integrations/vissl/transforms/utilities.py rename to flash/image/embedding/vissl/transforms/utilities.py index 3590011947..b3e94d2378 100644 --- a/flash/core/integrations/vissl/transforms/utilities.py +++ b/flash/image/embedding/vissl/transforms/utilities.py @@ -16,7 +16,7 @@ from flash.core.data.data_source import DefaultDataKeys -def vissl_collate_fn(samples): +def multicrop_collate_fn(samples): """Custom collate function for VISSL integration. Run custom collate on a single key since VISSL transforms affect only DefaultDataKeys.INPUT diff --git a/flash_examples/image_embedder.py b/flash_examples/image_embedder.py index 5a4de94fcf..e34c799037 100644 --- a/flash_examples/image_embedder.py +++ b/flash_examples/image_embedder.py @@ -11,15 +11,51 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch +from torchvision.datasets import CIFAR10 + +import flash +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.transforms import ApplyToKeys from flash.core.data.utils import download_data -from flash.image import ImageEmbedder +from flash.image import ImageClassificationData, ImageEmbedder +from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS +from flash.image.embedding.vissl.transforms import multicrop_collate_fn -# 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") +# 1. Download the data and pre-process the data +transform = IMAGE_EMBEDDER_TRANSFORMS.get("simclr_transform")() + +to_tensor_transform = ApplyToKeys( + DefaultDataKeys.INPUT, + transform, +) + +datamodule = ImageClassificationData.from_datasets( + train_dataset=CIFAR10(".", download=True), + train_transform={ + "to_tensor_transform": to_tensor_transform, + "collate": multicrop_collate_fn, + }, + batch_size=16, +) # 2. Build the task -embedder = ImageEmbedder(backbone="resnet101") +embedder = ImageEmbedder( + backbone="resnet", + training_strategy="barlow_twins", + head="simclr_head", + latent_embedding_dim=128, +) + +# 3. Create the trainer and pre-train the encoder +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) +trainer.fit(embedder, datamodule=datamodule) + +# 4. Save the model! +trainer.save_checkpoint("image_embedder_model.pt") + +# 5. Download the downstream prediction dataset and generate embeddings +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") -# 3. Generate an embedding from an image path. 
embeddings = embedder.predict(["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"]) print(embeddings) diff --git a/tests/core/integrations/vissl/test_transforms.py b/tests/core/integrations/vissl/test_transforms.py index d40913f58f..06b6d1efa3 100644 --- a/tests/core/integrations/vissl/test_transforms.py +++ b/tests/core/integrations/vissl/test_transforms.py @@ -14,18 +14,8 @@ import pytest from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import DefaultPreprocess -from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE -from flash.image import ImageClassificationData - -if _TORCHVISION_AVAILABLE: - from torchvision.datasets import FakeData - -if _VISSL_AVAILABLE: - from classy_vision.dataset.transforms import TRANSFORM_REGISTRY - - from flash.core.integrations.vissl.transforms import vissl_collate_fn +from tests.image.embedding.utils import ssl_datamodule @pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") @@ -36,28 +26,13 @@ def test_multicrop_input_transform(): size_crops = [160, 96] crop_scales = [[0.4, 1], [0.05, 0.4]] - multi_crop_transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"]( - total_crops, num_crops, size_crops, crop_scales - ) - - to_tensor_transform = ApplyToKeys( - DefaultDataKeys.INPUT, - multi_crop_transform, - ) - preprocess = DefaultPreprocess( - train_transform={ - "to_tensor_transform": to_tensor_transform, - "collate": vissl_collate_fn, - } - ) - - datamodule = ImageClassificationData.from_datasets( - train_dataset=FakeData(), - preprocess=preprocess, + train_dataloader = ssl_datamodule( batch_size=batch_size, - ) - - train_dataloader = datamodule._train_dataloader() + total_crops=total_crops, + num_crops=num_crops, + size_crops=size_crops, + crop_scales=crop_scales, + )._train_dataloader() batch = next(iter(train_dataloader)) assert len(batch[DefaultDataKeys.INPUT]) == total_crops diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py index e823212ef7..6a3ca358d1 100644 --- a/tests/image/embedding/test_model.py +++ b/tests/image/embedding/test_model.py @@ -11,36 +11,57 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
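+# End-to-end smoke tests for the VISSL training strategies; the shared
+# ``FakeData`` datamodule helper lives in ``tests/image/embedding/utils.py``.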
-import os import re import pytest import torch -from flash.core.utilities.imports import _IMAGE_AVAILABLE +import flash +from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE from flash.image import ImageEmbedder -from tests.helpers.utils import _IMAGE_TESTING +from tests.image.embedding.utils import ssl_datamodule - -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32),))]) -def test_jit(tmpdir, jitter, args): - path = os.path.join(tmpdir, "test.pt") - - model = ImageEmbedder(embedding_dim=128) - model.eval() - - model = jitter(model, *args) - - torch.jit.save(model, path) - model = torch.jit.load(path) - - out = model(torch.rand(1, 3, 32, 32)) - assert isinstance(out, torch.Tensor) - assert out.shape == torch.Size([1, 128]) +# TODO: Figure out why VISSL can't be jitted +# @pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") +# @pytest.mark.parametrize("jitter, args", [(torch.jit.trace, (torch.rand(1, 3, 64, 64),))]) +# def test_jit(tmpdir, jitter, args): +# path = os.path.join(tmpdir, "test.pt") +# +# model = ImageEmbedder(training_strategy="barlow_twins") +# model.eval() +# +# model = jitter(model, *args) +# +# torch.jit.save(model, path) +# model = torch.jit.load(path) +# +# out = model(torch.rand(1, 3, 64, 64)) +# assert isinstance(out, torch.Tensor) +# assert out.shape == torch.Size([1, 2048]) @pytest.mark.skipif(_IMAGE_AVAILABLE, reason="image libraries are installed.") def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")): ImageEmbedder.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") +@pytest.mark.parametrize("backbone, training_strategy", [("resnet", "barlow_twins")]) +def test_vissl_training(tmpdir, backbone, training_strategy): + datamodule = ssl_datamodule( + total_crops=2, + num_crops=[2], + size_crops=[96], + crop_scales=[[0.4, 1]], + ) + + embedder = ImageEmbedder( + backbone=backbone, + training_strategy=training_strategy, + head="simclr_head", + latent_embedding_dim=128, + ) + + trainer = flash.Trainer(max_steps=3, max_epochs=1, gpus=torch.cuda.device_count()) + trainer.fit(embedder, datamodule=datamodule) diff --git a/tests/image/embedding/utils.py b/tests/image/embedding/utils.py new file mode 100644 index 0000000000..1db9cbe3f6 --- /dev/null +++ b/tests/image/embedding/utils.py @@ -0,0 +1,44 @@ +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.process import DefaultPreprocess +from flash.core.data.transforms import ApplyToKeys +from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE +from flash.image import ImageClassificationData +from flash.image.embedding.vissl.transforms import multicrop_collate_fn + +if _TORCHVISION_AVAILABLE: + from torchvision.datasets import FakeData + +if _VISSL_AVAILABLE: + from classy_vision.dataset.transforms import TRANSFORM_REGISTRY + + +def ssl_datamodule( + batch_size=2, + total_crops=4, + num_crops=[2, 2], + size_crops=[160, 96], + crop_scales=[[0.4, 1], [0.05, 0.4]], + collate_fn=multicrop_collate_fn, +): + multi_crop_transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"]( + total_crops, num_crops, size_crops, crop_scales + ) + + to_tensor_transform = 
ApplyToKeys( + DefaultDataKeys.INPUT, + multi_crop_transform, + ) + preprocess = DefaultPreprocess( + train_transform={ + "to_tensor_transform": to_tensor_transform, + "collate": collate_fn, + } + ) + + datamodule = ImageClassificationData.from_datasets( + train_dataset=FakeData(), + preprocess=preprocess, + batch_size=batch_size, + ) + + return datamodule
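+
+
+# Example usage (see tests/core/integrations/vissl/test_transforms.py):
+#
+#   train_dataloader = ssl_datamodule(batch_size=2)._train_dataloader()
+#   batch = next(iter(train_dataloader))
+#   assert len(batch[DefaultDataKeys.INPUT]) == 4  # default ``total_crops``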