diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 03d4975b1a..199cd6ab32 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -133,6 +133,11 @@ jobs: restore-keys: | ${{ runner.os }}-${{ matrix.python-version }}-${{ join(matrix.topic,'-') }}-${{ matrix.requires }}-pip- + - name: Install vissl + if: matrix.topic[1] == 'image_extras' + run: | + pip install git+https://github.com/facebookresearch/vissl.git@master + - name: Install dependencies run: | python --version diff --git a/docs/source/reference/summarization.rst b/docs/source/reference/summarization.rst index 38b1fbb437..ed4cb3bbf0 100644 --- a/docs/source/reference/summarization.rst +++ b/docs/source/reference/summarization.rst @@ -93,6 +93,8 @@ You can now perform inference from your client like this: ------ +.. _summarization_ort: + ********************************************** Accelerate Training & Inference with Torch ORT ********************************************** diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst index 49c69a4f63..e2142819b3 100644 --- a/docs/source/reference/text_classification.rst +++ b/docs/source/reference/text_classification.rst @@ -93,6 +93,8 @@ You can now perform inference from your client like this: ------ +.. _text_classification_ort: + ********************************************** Accelerate Training & Inference with Torch ORT ********************************************** diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst index e3422d8cb6..bc37ad67eb 100644 --- a/docs/source/reference/translation.rst +++ b/docs/source/reference/translation.rst @@ -93,6 +93,8 @@ You can now perform inference from your client like this: ------ +.. 
def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
    """Apply the wrapped transforms to the entries of ``x`` selected by ``self.keys``.

    Keys listed in ``self.keys`` but missing from ``x`` are ignored. With a
    single matching key, its value is handed directly to the parent
    ``forward``. With several matching keys, the list of their values (in
    ``self.keys`` order) is handed over in one call and the ``i``-th output is
    stored under the ``i``-th matching key.

    Args:
        x: The sample to transform.

    Returns:
        A new ``dict`` containing every entry of ``x``, with the selected
        entries replaced by their transformed values. When no key matches,
        this is just a shallow copy of ``x``.

    Raises:
        Exception: If the parent ``forward`` raises ``TypeError`` when given
            the values of several keys at once.
    """
    present_keys = [key for key in self.keys if key in x]

    # Shallow copy so that the caller's mapping is never rebound in place.
    transformed = dict(x)

    if not present_keys:
        return transformed

    if len(present_keys) == 1:
        (only_key,) = present_keys
        transformed[only_key] = super().forward(x[only_key])
        return transformed

    values = [x[key] for key in present_keys]
    try:
        outputs = super().forward(values)
    except TypeError as e:
        raise Exception(
            "Failed to apply transforms to multiple keys at the same time,"
            " try using KorniaParallelTransforms."
        ) from e

    # Outputs are matched to keys by position.
    for position, key in enumerate(present_keys):
        transformed[key] = outputs[position]
    return transformed
class StandardMultiCropSSLTransform(nn.Module):
    """Turn one PIL image into a list of multi-resolution crops.

    The input is a PIL image. The output is a list with one entry per crop;
    because every crop pipeline ends with ``ToTensor`` (optionally followed by
    ``normalize``), the entries are tensors, not PIL images.

    This transform was proposed in SwAV - https://arxiv.org/abs/2006.09882

    This transform has been modified from the ImgPilToMultiCrop code present at
    https://github.com/facebookresearch/vissl/blob/master/vissl/data/ssl_transforms/img_pil_to_multicrop.py
    """

    def __init__(
        self,
        total_num_crops: int,
        num_crops: Sequence[int],
        size_crops: Sequence[int],
        crop_scales: Sequence[Sequence[float]],
        gaussian_blur: bool = True,
        jitter_strength: float = 1.0,
        normalize: Optional[nn.Module] = None,
    ):
        """Build the per-crop transform pipelines.

        Each crop is a random square crop extracted according to the matching
        entries of ``size_crops`` and ``crop_scales``. ``num_crops`` says how
        many crops of each type to extract, which removes the need to repeat
        the size/scale parameters.

        Args:
            total_num_crops: Total number of crops to extract. Must equal
                ``sum(num_crops)``.
            num_crops: Number of crops of each type.
            size_crops: Side length (height = width) of the crops of each type.
            crop_scales: ``(min, max)`` scale of the crops of each type, passed
                as ``scale`` to ``RandomResizedCrop``.
            gaussian_blur: Whether the colour transform includes a randomly
                applied Gaussian blur.
            jitter_strength: Coefficient applied to the colour jitter
                parameters.
            normalize: Optional normalisation transform applied after
                ``ToTensor``, with parameters set according to the dataset.

        Raises:
            AssertionError: If ``num_crops`` does not sum to
                ``total_num_crops`` or the three per-type sequences differ in
                length.

        Example usage:
            - (total_num_crops=2, num_crops=[1, 1],
              size_crops=[224, 96], crop_scales=[(0.14, 1.), (0.05, 0.14)])
              Extracts 2 crops total of size 224x224 and 96x96
            - (total_num_crops=3, num_crops=[1, 2],
              size_crops=[224, 96], crop_scales=[(0.14, 1.), (0.05, 0.14)])
              Extracts 3 crops total: 1 of size 224x224 and 2 of size 96x96
        """
        super().__init__()

        assert (
            np.sum(num_crops) == total_num_crops
        ), f"num_crops {list(num_crops)} must sum to total_num_crops={total_num_crops}"
        assert len(size_crops) == len(
            num_crops
        ), f"size_crops has {len(size_crops)} entries but num_crops has {len(num_crops)}"
        assert len(size_crops) == len(
            crop_scales
        ), f"size_crops has {len(size_crops)} entries but crop_scales has {len(crop_scales)}"

        self.gaussian_blur = gaussian_blur
        self.jitter_strength = jitter_strength
        self.normalize = normalize

        color_jitter = pth_transforms.ColorJitter(
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.2 * self.jitter_strength,
        )
        color_transform = [pth_transforms.RandomApply([color_jitter], p=0.8), pth_transforms.RandomGrayscale(p=0.2)]

        if self.gaussian_blur:
            # Kernel is about 10% of the first crop size, bumped to the next
            # odd number when it comes out even.
            kernel_size = int(0.1 * size_crops[0])
            if kernel_size % 2 == 0:
                kernel_size += 1

            color_transform.append(
                pth_transforms.RandomApply([pth_transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5)
            )

        self.color_transform = pth_transforms.Compose(color_transform)

        if normalize is None:
            self.final_transform = pth_transforms.ToTensor()
        else:
            self.final_transform = pth_transforms.Compose([pth_transforms.ToTensor(), normalize])

        transforms = []
        for num, size, scale in zip(num_crops, size_crops, crop_scales):
            # The same pipeline object is repeated ``num`` times; each call
            # still draws its own random crop/flip/colour parameters.
            transforms.extend(
                [
                    pth_transforms.Compose(
                        [
                            pth_transforms.RandomResizedCrop(size, scale=scale),
                            pth_transforms.RandomHorizontalFlip(p=0.5),
                            self.color_transform,
                            self.final_transform,
                        ]
                    )
                ]
                * num
            )

        self.transforms = transforms

    # NOTE(review): this overrides ``nn.Module.__call__`` instead of defining
    # ``forward``, so module hooks are presumably bypassed -- confirm that this
    # is intended.
    def __call__(self, image: "Image.Image") -> List:
        """Apply every crop pipeline to ``image``.

        Args:
            image: The PIL image to crop.

        Returns:
            One transformed crop per pipeline, in the order given by
            ``num_crops``. Each entry is the output of ``ToTensor`` (and
            ``normalize`` when provided), i.e. a tensor rather than a PIL
            image.
        """
        images = [transform(image) for transform in self.transforms]
        return images
def vissl_collate_fn(samples, input_key=None):
    """Custom collate function for VISSL integration.

    Runs the default collate on every key except the multi-crop input, since
    VISSL transforms affect only that key. The input of each sample is expected
    to be a sequence of crops; crops at the same position are stacked across
    the batch.

    Args:
        samples: The list of sample mappings produced by the dataset.
        input_key: Key under which each sample stores its sequence of crops.
            Defaults to ``DefaultDataKeys.INPUT``.

    Returns:
        The default-collated batch, where ``batch[input_key]`` is a list with
        one stacked tensor per crop position.

    Raises:
        ValueError: If the samples do not all contain the same number of
            crops. Previously a sample with fewer crops silently produced
            stacked crops with different batch sizes, and a sample with more
            crops failed with a bare ``IndexError``.
    """
    # Imported explicitly instead of relying on ``torch.utils.data._utils``
    # being reachable as an attribute after a plain ``import torch``.
    from torch.utils.data._utils.collate import default_collate

    if input_key is None:
        # Resolved lazily so the default stays ``DefaultDataKeys.INPUT``.
        input_key = DefaultDataKeys.INPUT

    # Collate everything but the crops: the crops are swapped for a scalar
    # placeholder so that default_collate never sees the ragged crop lists,
    # while the key keeps its position in the resulting batch.
    placeholders = []
    for sample in samples:
        stripped = dict(sample)
        stripped[input_key] = -1
        placeholders.append(stripped)
    result = default_collate(placeholders)

    expected_crops = len(samples[0][input_key])
    per_crop = [[] for _ in range(expected_crops)]
    for sample_idx, sample in enumerate(samples):
        crops = sample[input_key]
        if len(crops) != expected_crops:
            raise ValueError(
                f"Sample {sample_idx} has {len(crops)} crops but sample 0 has {expected_crops};"
                " all samples in a batch must contain the same number of crops."
            )
        for bucket, crop in zip(per_crop, crops):
            bucket.append(crop)

    result[input_key] = [torch.stack(bucket) for bucket in per_crop]
    return result
def request_dataloader(
    self,
    *args,
) -> Union[DataLoader, List[DataLoader]]:
    """Fetch the dataloader(s) for a stage through the trainer hooks.

    Two positional calling conventions are accepted and told apart by the type
    of the first argument:

    * ``(model, stage)`` when the first argument is a ``LightningModule``:
      ``on_{stage}_dataloader`` is dispatched and the dataloader comes from
      calling ``model.{stage}_dataloader()`` directly.
    * ``(stage, model)`` otherwise: the hook name is built from
      ``stage.dataloader_prefix`` and both the ``on_`` hook and the dataloader
      hook are dispatched with ``pl_module=model``.

    NOTE(review): the two conventions presumably track different Lightning
    releases -- confirm against the supported versions.

    Returns:
        The dataloader. A ``tuple`` of dataloaders is converted to a ``list``.
    """
    leading = args[0]
    if not isinstance(leading, LightningModule):
        stage, model = args
        hook_name = f"{stage.dataloader_prefix}_dataloader"
        self.call_hook("on_" + hook_name, pl_module=model)
        loaders = self.call_hook(hook_name, pl_module=model)
    else:
        model, stage = args
        self.call_hook(f"on_{stage}_dataloader")
        loaders = getattr(model, f"{stage}_dataloader")()

    loaders = list(loaders) if isinstance(loaders, tuple) else loaders

    # Barrier on the accelerator once the dataloaders have been obtained.
    self.accelerator.barrier("get_dataloaders")
    return loaders
def predict_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """Build the default predict-time transforms.

    During predict, the resize is applied only to ``DefaultDataKeys.INPUT``.

    Args:
        image_size: Target size handed to ``K.geometry.Resize``.

    Returns:
        A mapping with a ``"post_tensor_transform"`` entry (the input-only
        resize using ``"nearest"`` interpolation) and a ``"collate"`` entry
        (``kornia_collate``).
    """
    resize_input = ApplyToKeys(
        DefaultDataKeys.INPUT,
        K.geometry.Resize(image_size, interpolation="nearest"),
    )
    transforms: Dict[str, Callable] = {"post_tensor_transform": nn.Sequential(resize_input)}
    transforms["collate"] = kornia_collate
    return transforms
@pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.")
def test_multicrop_input_transform():
    """Check that the registered multi-crop transform plus the VISSL collate yield per-crop batches.

    The batch input must be a list with one stacked tensor per crop, shaped
    ``(batch_size, 3, size, size)`` for the first and last crop types, and the
    targets must be collated to length ``batch_size``.
    """
    batch_size = 8
    total_crops = 6
    num_crops = [2, 4]
    size_crops = [160, 96]
    crop_scales = [[0.4, 1], [0.05, 0.4]]

    # Look the transform up through the classy_vision registry so that the
    # registration itself is exercised too.
    multicrop_cls = TRANSFORM_REGISTRY["multicrop_ssl_transform"]
    multicrop = multicrop_cls(total_crops, num_crops, size_crops, crop_scales)

    train_transform = {
        "to_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, multicrop),
        "collate": vissl_collate_fn,
    }
    datamodule = ImageClassificationData.from_datasets(
        train_dataset=FakeData(),
        preprocess=DefaultPreprocess(train_transform=train_transform),
        batch_size=batch_size,
    )

    batch = next(iter(datamodule._train_dataloader()))
    crops = batch[DefaultDataKeys.INPUT]
    largest, smallest = size_crops[0], size_crops[-1]

    assert len(crops) == total_crops
    assert crops[0].shape == (batch_size, 3, largest, largest)
    assert crops[-1].shape == (batch_size, 3, smallest, smallest)
    assert list(batch[DefaultDataKeys.TARGET].shape) == [batch_size]