
Merge branch 'master' into feature/768-text-from-data-frame
ethanwharris authored Sep 22, 2021
2 parents 397bd71 + 40cb3ab commit 2479d08
Showing 6 changed files with 81 additions and 37 deletions.
32 changes: 32 additions & 0 deletions docs/source/reference/object_detection.rst
@@ -73,3 +73,35 @@ To view configuration options and options for running the object detector with your own data, use:
.. code-block:: bash

    flash object_detection --help
------

**********************
Custom Transformations
**********************

Flash automatically applies some default image and mask transformations and augmentations, but you may wish to customize these for your own use case.
The base :class:`~flash.core.data.process.Preprocess` defines seven hooks for different stages of the data loading pipeline.
For object-detection tasks, you can leverage the transformations from `Albumentations <https://github.com/albumentations-team/albumentations>`__ with the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter`.

.. code-block:: python

    import albumentations as alb
    from icevision.tfms import A

    from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
    from flash.image import ObjectDetectionData

    train_transform = {
        "pre_tensor_transform": IceVisionTransformAdapter(
            [*A.resize_and_pad(128), A.Normalize(), A.Flip(p=0.4), alb.RandomBrightnessContrast()]
        )
    }

    datamodule = ObjectDetectionData.from_coco(
        train_folder="data/coco128/images/train2017/",
        train_ann_file="data/coco128/annotations/instances_train2017.json",
        val_split=0.1,
        image_size=128,
        train_transform=train_transform,
    )
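
The resulting ``datamodule`` can then be used like any other. As a minimal usage sketch (assuming the standard :class:`~flash.image.ObjectDetector` task, which is not part of the snippet above):

.. code-block:: python

    from flash import Trainer
    from flash.image import ObjectDetector

    # Build a detector; the class count is read off the DataModule above.
    model = ObjectDetector(num_classes=datamodule.num_classes)

    # Fit briefly using the custom-transformed data.
    trainer = Trainer(max_epochs=1)
    trainer.fit(model, datamodule=datamodule)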
36 changes: 30 additions & 6 deletions flash/core/adapter.py
@@ -110,7 +110,7 @@ def test_epoch_end(self, outputs) -> None:
     def process_train_dataset(
         self,
         dataset: BaseAutoDataset,
-        trainer: flash.Trainer,
+        trainer: "flash.Trainer",
         batch_size: int,
         num_workers: int,
         pin_memory: bool,
@@ -120,13 +120,21 @@ def test_epoch_end(self, outputs) -> None:
         sampler: Optional[Sampler] = None,
     ) -> DataLoader:
         return self.adapter.process_train_dataset(
-            dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
+            dataset,
+            trainer,
+            batch_size,
+            num_workers,
+            pin_memory,
+            collate_fn,
+            shuffle,
+            drop_last,
+            sampler,
         )

     def process_val_dataset(
         self,
         dataset: BaseAutoDataset,
-        trainer: flash.Trainer,
+        trainer: "flash.Trainer",
         batch_size: int,
         num_workers: int,
         pin_memory: bool,
@@ -136,13 +144,21 @@ def process_val_dataset(
         sampler: Optional[Sampler] = None,
     ) -> DataLoader:
         return self.adapter.process_val_dataset(
-            dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
+            dataset,
+            trainer,
+            batch_size,
+            num_workers,
+            pin_memory,
+            collate_fn,
+            shuffle,
+            drop_last,
+            sampler,
         )

     def process_test_dataset(
         self,
         dataset: BaseAutoDataset,
-        trainer: flash.Trainer,
+        trainer: "flash.Trainer",
         batch_size: int,
         num_workers: int,
         pin_memory: bool,
@@ -152,7 +168,15 @@ def process_test_dataset(
         sampler: Optional[Sampler] = None,
     ) -> DataLoader:
         return self.adapter.process_test_dataset(
-            dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
+            dataset,
+            trainer,
+            batch_size,
+            num_workers,
+            pin_memory,
+            collate_fn,
+            shuffle,
+            drop_last,
+            sampler,
         )

     def process_predict_dataset(
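
Why the annotation is now quoted: a string annotation is a forward reference that Python never evaluates at runtime, so this module no longer needs flash to be importable (or fully initialized) when the class is defined; this is a standard way to break circular imports. A minimal standalone sketch of the pattern (hypothetical function, not code from this commit):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only during static analysis, never at runtime, so a
        # circular import between this module and `flash` cannot occur.
        import flash

    def process(trainer: "flash.Trainer") -> None:
        # The string annotation stays unevaluated at runtime.
        ...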
8 changes: 7 additions & 1 deletion flash/core/integrations/icevision/transforms.py
@@ -196,7 +196,13 @@ def from_icevision_predictions(predictions: List["Prediction"]):


 class IceVisionTransformAdapter(nn.Module):
-    def __init__(self, transform):
+    """
+    Args:
+        transform: list of transformation functions to apply
+    """
+
+    def __init__(self, transform: List[Callable]):
         super().__init__()
         self.transform = A.Adapter(transform)

9 changes: 3 additions & 6 deletions flash/core/model.py
@@ -144,7 +144,6 @@ def process_train_dataset(
         shuffle: bool = False,
         drop_last: bool = True,
         sampler: Optional[Sampler] = None,
-        persistent_workers: bool = True,
     ) -> DataLoader:
         return self._process_dataset(
             dataset,
@@ -155,7 +154,7 @@ def process_train_dataset(
             shuffle=shuffle,
             drop_last=drop_last,
             sampler=sampler,
-            persistent_workers=persistent_workers and num_workers > 0,
+            persistent_workers=num_workers > 0,
         )

     def process_val_dataset(
@@ -169,7 +168,6 @@ def process_val_dataset(
         shuffle: bool = False,
         drop_last: bool = False,
         sampler: Optional[Sampler] = None,
-        persistent_workers: bool = True,
     ) -> DataLoader:
         return self._process_dataset(
             dataset,
@@ -180,7 +178,7 @@ def process_val_dataset(
             shuffle=shuffle,
             drop_last=drop_last,
             sampler=sampler,
-            persistent_workers=persistent_workers and num_workers > 0,
+            persistent_workers=num_workers > 0,
         )

     def process_test_dataset(
@@ -194,7 +192,6 @@ def process_test_dataset(
         shuffle: bool = False,
         drop_last: bool = False,
         sampler: Optional[Sampler] = None,
-        persistent_workers: bool = True,
     ) -> DataLoader:
         return self._process_dataset(
             dataset,
@@ -205,7 +202,7 @@ def process_test_dataset(
             shuffle=shuffle,
             drop_last=drop_last,
             sampler=sampler,
-            persistent_workers=persistent_workers and num_workers > 0,
+            persistent_workers=num_workers > 0,
         )

     def process_predict_dataset(
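
Why the num_workers guard: PyTorch's DataLoader raises a ValueError when persistent_workers=True is combined with num_workers=0, so deriving the flag from num_workers avoids that failure mode. A small standalone sketch (toy dataset; not code from this commit):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(8.0))

    def make_loader(num_workers: int) -> DataLoader:
        # persistent_workers=True with num_workers=0 raises ValueError,
        # so the flag is derived from the worker count instead.
        return DataLoader(
            dataset,
            batch_size=2,
            num_workers=num_workers,
            persistent_workers=num_workers > 0,
        )

    make_loader(0)  # ok: no workers, flag stays False
    make_loader(2)  # ok: workers are kept alive between epochs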
16 changes: 2 additions & 14 deletions flash/core/utilities/imports.py
@@ -17,7 +17,6 @@
 import types
 from importlib.util import find_spec
 from typing import List, Union
-from warnings import warn
 
 from pkg_resources import DistributionNotFound

@@ -107,19 +106,8 @@ def _compare_version(package: str, op, version) -> bool:
     from PIL import Image  # noqa: F401
 else:
 
-    class MetaImage(type):
-        def __init__(cls, name, bases, dct):
-            super().__init__(name, bases, dct)
-
-            cls._Image = None
-
-        @property
-        def Image(cls):
-            warn("Mock object called due to missing PIL library. Please use \"pip install 'lightning-flash[image]'\".")
-            return cls._Image
-
-    class Image(metaclass=MetaImage):
-        pass
+    class Image:
+        Image = object


 if Version:
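
The mock that this hunk replaces exists because the file gates optional imports on availability checks. A simplified sketch of that pattern (distilled for illustration; not the verbatim implementation in flash.core.utilities.imports):

    from importlib.util import find_spec

    def module_available(module_path: str) -> bool:
        # find_spec returns None when the module cannot be found; it raises
        # ModuleNotFoundError when a parent package is missing.
        try:
            return find_spec(module_path) is not None
        except ModuleNotFoundError:
            return False

    if module_available("PIL"):
        from PIL import Image  # noqa: F401
    else:
        # Lightweight stand-in so annotations like Image.Image still resolve.
        class Image:
            Image = object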
17 changes: 7 additions & 10 deletions flash/image/classification/adapters.py
@@ -206,7 +206,7 @@ def _labels_to_indices(data):

     def _convert_dataset(
         self,
-        trainer: flash.Trainer,
+        trainer: "flash.Trainer",
         dataset: BaseAutoDataset,
         ways: int,
         shots: int,
@@ -334,14 +334,14 @@ def _sanetize_batch_size(self, batch_size: int) -> int:
     def process_train_dataset(
         self,
         dataset: BaseAutoDataset,
-        trainer: flash.Trainer,
+        trainer: "flash.Trainer",
         batch_size: int,
         num_workers: int,
         pin_memory: bool,
         collate_fn: Callable,
-        shuffle: bool,
-        drop_last: bool,
-        sampler: Optional[Sampler],
+        shuffle: bool = False,
+        drop_last: bool = False,
+        sampler: Optional[Sampler] = None,
     ) -> DataLoader:
         dataset = self._convert_dataset(
             trainer=trainer,
@@ -366,13 +366,12 @@ def process_train_dataset(
             shuffle=shuffle,
             drop_last=drop_last,
             sampler=sampler,
-            persistent_workers=True,
         )

     def process_val_dataset(
         self,
         dataset: BaseAutoDataset,
-        trainer: flash.Trainer,
+        trainer: "flash.Trainer",
         batch_size: int,
         num_workers: int,
         pin_memory: bool,
@@ -404,13 +403,12 @@ def process_val_dataset(
             shuffle=shuffle,
             drop_last=drop_last,
             sampler=sampler,
-            persistent_workers=True,
         )

     def process_test_dataset(
         self,
         dataset: BaseAutoDataset,
-        trainer: flash.Trainer,
+        trainer: "flash.Trainer",
         batch_size: int,
         num_workers: int,
         pin_memory: bool,
@@ -442,7 +440,6 @@ def process_test_dataset(
             shuffle=shuffle,
             drop_last=drop_last,
             sampler=sampler,
-            persistent_workers=True,
         )

     def process_predict_dataset(
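
The added defaults mean callers of the override may omit shuffle, drop_last, and sampler, mirroring how the base hooks in flash/core/model.py shown earlier are declared. A tiny standalone sketch of why this matters (hypothetical names):

    class Base:
        def load(self, shuffle: bool = False) -> str:
            return f"base(shuffle={shuffle})"

    class Adapter(Base):
        # Without the default, Adapter().load() would raise TypeError even
        # though the same call succeeds on the base class.
        def load(self, shuffle: bool = False) -> str:
            return f"adapter(shuffle={shuffle})"

    print(Adapter().load())  # -> adapter(shuffle=False)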
