Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into feature/577_add_SQuADDataSource
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Aug 31, 2021
2 parents c149c37 + c512c31 commit a8b24d6
Show file tree
Hide file tree
Showing 17 changed files with 316 additions and 12 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ jobs:
restore-keys: |
${{ runner.os }}-${{ matrix.python-version }}-${{ join(matrix.topic,'-') }}-${{ matrix.requires }}-pip-
- name: Install vissl
if: matrix.topic[1] == 'image_extras'
run: |
pip install git+https://github.com/facebookresearch/vissl.git@master
- name: Install dependencies
run: |
python --version
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/summarization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ You can now perform inference from your client like this:

------

.. _summarization_ort:

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/text_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ You can now perform inference from your client like this:

------

.. _text_classification_ort:

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/translation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ You can now perform inference from your client like this:

------

.. _translation_ort:

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************
Expand Down
29 changes: 18 additions & 11 deletions flash/core/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,26 @@ def __init__(self, keys: Union[str, Sequence[str]], *args):
def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
keys = list(filter(lambda key: key in x, self.keys))
inputs = [x[key] for key in keys]
if len(inputs) > 0:
if len(inputs) == 1:
inputs = inputs[0]
outputs = super().forward(inputs)
if not isinstance(outputs, Sequence):
outputs = (outputs,)

result = {}
result.update(x)

result = {}
result.update(x)

if len(inputs) == 1:
result[keys[0]] = super().forward(inputs[0])
elif len(inputs) > 1:
try:
outputs = super().forward(inputs)
except TypeError as e:
raise Exception(
"Failed to apply transforms to multiple keys at the same time,"
" try using KorniaParallelTransforms."
) from e

for i, key in enumerate(keys):
result[key] = outputs[i]
return result
return x

# result is simply returned if len(inputs) == 0
return result

def __repr__(self):
transform = list(self.children())
Expand Down
Empty file.
9 changes: 9 additions & 0 deletions flash/core/integrations/vissl/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from flash.core.utilities.imports import _VISSL_AVAILABLE # noqa: F401

if _VISSL_AVAILABLE:
from classy_vision.dataset.transforms import register_transform # noqa: F401

from flash.core.integrations.vissl.transforms.multicrop import StandardMultiCropSSLTransform # noqa: F401
from flash.core.integrations.vissl.transforms.utilities import vissl_collate_fn # noqa: F401

register_transform("multicrop_ssl_transform")(StandardMultiCropSSLTransform)
122 changes: 122 additions & 0 deletions flash/core/integrations/vissl/transforms/multicrop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Sequence

import numpy as np
import torch.nn as nn

from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image

if _TORCHVISION_AVAILABLE:
import torchvision.transforms as pth_transforms


class StandardMultiCropSSLTransform(nn.Module):
"""Convert a PIL image to Multi-resolution Crops. The input is a PIL image and output is the list of image
crops.
This transform was proposed in SwAV - https://arxiv.org/abs/2006.09882
This transform has been modified from the ImgPilToMultiCrop code present at
https://github.com/facebookresearch/vissl/blob/master/vissl/data/ssl_transforms/img_pil_to_multicrop.py
"""

def __init__(
self,
total_num_crops: int,
num_crops: Sequence[int],
size_crops: Sequence[int],
crop_scales: Sequence[Sequence[float]],
gaussian_blur: bool = True,
jitter_strength: float = 1.0,
normalize: Optional[nn.Module] = None,
):
"""Returns total_num_crops square crops of an image. Each crop is a random crop extracted according to the
parameters specified in size_crops and crop_scales. For ease of use, one can specify `num_crops` which
removes the need to repeat parameters.
Args:
total_num_crops (int): Total number of crops to extract
num_crops (List or Tuple of ints): Specifies the number of `type' of crops.
size_crops (List or Tuple of ints): Specifies the height (height = width)
of each patch
crop_scales (List or Tuple containing [float, float]): Scale of the crop
gaussian_blur (bool): Specifies if the transforms composition has Gaussian Blur
jitter_strength (float): Specify the coefficient for color jitter transform
normalize (Optional): Normalize transform from torchvision with params set
according to the dataset
Example usage:
- (total_num_crops=2, num_crops=[1, 1],
size_crops=[224, 96], crop_scales=[(0.14, 1.), (0.05, 0.14)])
Extracts 2 crops total of size 224x224 and 96x96
- (total_num_crops=3, num_crops=[1, 2],
size_crops=[224, 96], crop_scales=[(0.14, 1.), (0.05, 0.14)])
Extracts 3 crops total: 1 of size 224x224 and 2 of size 96x96
"""
super().__init__()

assert np.sum(num_crops) == total_num_crops
assert len(size_crops) == len(num_crops)
assert len(size_crops) == len(crop_scales)

self.gaussian_blur = gaussian_blur
self.jitter_strength = jitter_strength
self.normalize = normalize

color_jitter = pth_transforms.ColorJitter(
0.8 * self.jitter_strength,
0.8 * self.jitter_strength,
0.8 * self.jitter_strength,
0.2 * self.jitter_strength,
)
color_transform = [pth_transforms.RandomApply([color_jitter], p=0.8), pth_transforms.RandomGrayscale(p=0.2)]

if self.gaussian_blur:
kernel_size = int(0.1 * size_crops[0])
if kernel_size % 2 == 0:
kernel_size += 1

color_transform.append(
pth_transforms.RandomApply([pth_transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5)
)

self.color_transform = pth_transforms.Compose(color_transform)

if normalize is None:
self.final_transform = pth_transforms.ToTensor()
else:
self.final_transform = pth_transforms.Compose([pth_transforms.ToTensor(), normalize])

transforms = []
for num, size, scale in zip(num_crops, size_crops, crop_scales):
transforms.extend(
[
pth_transforms.Compose(
[
pth_transforms.RandomResizedCrop(size, scale=scale),
pth_transforms.RandomHorizontalFlip(p=0.5),
self.color_transform,
self.final_transform,
]
)
]
* num
)

self.transforms = transforms

def __call__(self, image: Image.Image) -> List[Image.Image]:
images = [transform(image) for transform in self.transforms]
return images
47 changes: 47 additions & 0 deletions flash/core/integrations/vissl/transforms/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from flash.core.data.data_source import DefaultDataKeys


def vissl_collate_fn(samples):
"""Custom collate function for VISSL integration.
Run custom collate on a single key since VISSL transforms affect only DefaultDataKeys.INPUT
"""
result = []

for batch_ele in samples:
_batch_ele_dict = {}
_batch_ele_dict.update(batch_ele)
_batch_ele_dict[DefaultDataKeys.INPUT] = -1

result.append(_batch_ele_dict)

result = torch.utils.data._utils.collate.default_collate(result)

inputs = [[] for _ in range(len(samples[0][DefaultDataKeys.INPUT]))]
for batch_ele in samples:
multi_crop_imgs = batch_ele[DefaultDataKeys.INPUT]

for idx, crop in enumerate(multi_crop_imgs):
inputs[idx].append(crop)

for idx, ele in enumerate(inputs):
inputs[idx] = torch.stack(ele)

result[DefaultDataKeys.INPUT] = inputs

return result
23 changes: 23 additions & 0 deletions flash/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,26 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) ->
# the lightning trainer implementation does not support subclasses.
# context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447
return from_argparse_args(Trainer, args, **kwargs)

def request_dataloader(
self,
*args,
) -> Union[DataLoader, List[DataLoader]]:
"""Handles downloading data in the GPU or TPU case.
Returns:
The dataloader
"""
if isinstance(args[0], LightningModule):
model, stage = args
self.call_hook(f"on_{stage}_dataloader")
dataloader = getattr(model, f"{stage}_dataloader")()
else:
stage, model = args
hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = self.call_hook(hook, pl_module=model)
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.accelerator.barrier("get_dataloaders")
return dataloader
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _compare_version(package: str, op, version) -> bool:
_ICEVISION_AVAILABLE = _module_available("icevision")
_ICEDATA_AVAILABLE = _module_available("icedata")
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")
_VISSL_AVAILABLE = _module_available("vissl") and _module_available("classy_vision")

if _PIL_AVAILABLE:
from PIL import Image
Expand Down
5 changes: 4 additions & 1 deletion flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from flash.image.data import ImageDeserializer, IMG_EXTENSIONS
from flash.image.segmentation.serialization import SegmentationLabels
from flash.image.segmentation.transforms import default_transforms, train_default_transforms
from flash.image.segmentation.transforms import default_transforms, predict_default_transforms, train_default_transforms

SampleCollection = None
if _FIFTYONE_AVAILABLE:
Expand Down Expand Up @@ -284,6 +284,9 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]:
def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
return train_default_transforms(self.image_size)

def predict_default_transforms(self) -> Optional[Dict[str, Callable]]:
return predict_default_transforms(self.image_size)


class SemanticSegmentationData(DataModule):
"""Data module for semantic segmentation tasks."""
Expand Down
13 changes: 13 additions & 0 deletions flash/image/segmentation/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,16 @@ def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]
),
},
)


def predict_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
"""During predict, we apply the default transforms only on DefaultDataKeys.INPUT."""
return {
"post_tensor_transform": nn.Sequential(
ApplyToKeys(
DefaultDataKeys.INPUT,
K.geometry.Resize(image_size, interpolation="nearest"),
),
),
"collate": kornia_collate,
}
2 changes: 2 additions & 0 deletions requirements/datatype_image_extras.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
matplotlib
fiftyone
classy_vision
vissl>=0.1.5
Empty file.
Empty file.
Loading

0 comments on commit a8b24d6

Please sign in to comment.