Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[WIP] add style transfer task with pystiche #262

Merged
merged 72 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
413a530
add style transfer task with pystiche
pmeier May 5, 2021
be5a893
address review comments
pmeier May 10, 2021
398bbf3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2021
61c9480
fix type hint
pmeier May 10, 2021
0a70150
allow passing style_image by path
pmeier May 10, 2021
f6b2fcc
add batch_size
pmeier May 10, 2021
edf0dff
add data_module based on image classification
pmeier May 11, 2021
4ffc734
add internal pre / post-processing
pmeier May 11, 2021
9bb45bc
bail out if val / test step is performed
pmeier May 11, 2021
e82a94c
update example
pmeier May 11, 2021
d939fad
move example from predict to finetuning
pmeier May 11, 2021
5b10dbb
remove metrics from task
pmeier May 11, 2021
2e6901a
flake8
pmeier May 11, 2021
3c16bc0
Merge branch 'master' into style-transfer
pmeier May 11, 2021
eeed004
remove unused imports
pmeier May 11, 2021
7d38a5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2021
e844db4
remove grayscale handling
pmeier May 11, 2021
a932f86
address review comments and small fixes
pmeier May 11, 2021
fb30c16
Merge branch 'master' into style-transfer
pmeier May 11, 2021
ba091cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2021
e399d50
streamline apply_to_input
pmeier May 12, 2021
dc80a21
fix hyper parameters saving
pmeier May 12, 2021
9f7fd41
implement custom step
pmeier May 12, 2021
eabf49b
cleanup
pmeier May 12, 2021
54ae632
Merge branch 'master' into style-transfer
pmeier May 12, 2021
361074b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2021
464a26c
add explanation to not supported phases
pmeier May 14, 2021
d78989e
temporarily use unreleased pystiche version
pmeier May 14, 2021
1b2e6e3
add missing transforms in preprocess
pmeier May 14, 2021
0feaf7a
introduce multi layer encoders as backbones
pmeier May 14, 2021
c41e38c
refactor task
pmeier May 14, 2021
de62996
add explanation for modified gram operator
pmeier May 14, 2021
36ce2da
Merge branch 'master' into style-transfer
pmeier May 14, 2021
02a0c29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
081cb48
Merge branch 'master' into style-transfer
pmeier May 16, 2021
21d0817
streamline default transforms
pmeier May 16, 2021
f868489
add disabled test for finetuning example
pmeier May 16, 2021
39baffd
add documentation skeleton
pmeier May 16, 2021
7660f51
update changelog
pmeier May 16, 2021
46ff9d0
update
tchaton May 17, 2021
e230a2c
update
tchaton May 17, 2021
40fd08b
update
tchaton May 17, 2021
278a874
update
tchaton May 17, 2021
740bb22
update
tchaton May 17, 2021
3e2ad57
update
tchaton May 17, 2021
753dfd7
update
tchaton May 17, 2021
a49474a
update
tchaton May 17, 2021
55888c3
update
tchaton May 17, 2021
71b63f4
update
tchaton May 17, 2021
4ab9aae
update
tchaton May 17, 2021
4b1e7b9
update
tchaton May 17, 2021
5c3e72a
update
tchaton May 17, 2021
b2b3132
update
tchaton May 17, 2021
8d23b95
Merge branch 'master' into style-transfer
tchaton May 17, 2021
fa7c304
change skipif
tchaton May 17, 2021
35d5702
Merge branch 'style-transfer' of https://github.com/pmeier/lightning-…
tchaton May 17, 2021
ed1574f
update
tchaton May 17, 2021
f173788
update
tchaton May 17, 2021
472cb92
update
tchaton May 17, 2021
7ac8234
Merge branch 'master' into style-transfer
tchaton May 17, 2021
d2ab928
fix image size for preprocess
pmeier May 17, 2021
1a7819c
fix style transfer requirements
pmeier May 17, 2021
a4c7f0a
update
tchaton May 17, 2021
82f2005
Merge branch 'style-transfer' of https://github.com/pmeier/lightning-…
tchaton May 17, 2021
559446b
update doc
tchaton May 17, 2021
22ee835
Merge remote-tracking branch 'pmeier/style-transfer' into style-transfer
pmeier May 17, 2021
a3b95d5
fix style transfer requirements
pmeier May 17, 2021
b6e459b
Merge remote-tracking branch 'pmeier/style-transfer' into style-transfer
pmeier May 17, 2021
b8d93be
update
tchaton May 17, 2021
77db374
Merge branch 'style-transfer' of https://github.com/pmeier/lightning-…
tchaton May 17, 2021
4fbf11c
add reference to pystiche
pmeier May 17, 2021
31efb53
remove unnecessary import
pmeier May 17, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ jobs:
python-version: 3.8
requires: 'latest'
topic: 'text'
- os: ubuntu-20.04
python-version: 3.8
requires: 'latest'
topic: 'image_style_transfer'

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState ([#229](https://github.com/PyTorchLightning/lightning-flash/pull/229))
- Added Semantic Segmentation task ([#239](https://github.com/PyTorchLightning/lightning-flash/pull/239) [#287](https://github.com/PyTorchLightning/lightning-flash/pull/287) [#290](https://github.com/PyTorchLightning/lightning-flash/pull/290))
- Added Object detection prediction example ([#283](https://github.com/PyTorchLightning/lightning-flash/pull/283))
- Added Style Transfer task and accompanying finetuning and prediction examples ([#262](https://github.com/PyTorchLightning/lightning-flash/pull/262))

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Lightning Flash
reference/object_detection
reference/video_classification
reference/semantic_segmentation
reference/style_transfer

.. toctree::
:maxdepth: 1
Expand Down
74 changes: 74 additions & 0 deletions docs/source/reference/style_transfer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
##############
Style Transfer
##############

********
The task
********

The Neural Style Transfer Task is an optimization method which extract the style from an image and apply it another image while preserving its content.
The goal is that the output image looks like the content image, but “painted” in the style of the style reference image.

.. image:: https://raw.githubusercontent.com/pystiche/pystiche/master/docs/source/graphics/banner/banner.jpg
:alt: style_transfer_example

pmeier marked this conversation as resolved.
Show resolved Hide resolved
------

***
Fit
***

First, you would have to import the :class:`~flash.image.style_transfer.StyleTransfer`
and :class:`~flash.image.style_transfer.StyleTransferData` from Flash.

.. testcode:: style_transfer

import sys
import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData
import pystiche


Then, download some content images and create a :class:`~flash.image.style_transfer.StyleTransferData` DataModule.

.. testcode:: style_transfer

download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4)


Select a style image and pass it to the `StyleTransfer` task.

.. testcode:: style_transfer

style_image = pystiche.demo.images()["paint"].read(size=256)

model = StyleTransfer(style_image)

Finally, create a Flash :class:`flash.core.trainer.Trainer` and pass it the model and datamodule.

.. testcode:: style_transfer

trainer = flash.Trainer(max_epochs=2)
trainer.fit(model, data_module)
tchaton marked this conversation as resolved.
Show resolved Hide resolved


------

*************
API reference
*************

StyleTransfer
-------------

.. autoclass:: flash.image.StyleTransfer
:members:
:exclude-members: forward

StyleTransferData
-----------------

.. autoclass:: flash.image.StyleTransferData
7 changes: 5 additions & 2 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def _resolve_function_hierarchy(
if object_type is None:
object_type = Preprocess

prefixes = ['']
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefixes = []

if stage in (RunningStage.TRAINING, RunningStage.TUNING):
prefixes += ['train', 'fit']
elif stage == RunningStage.VALIDATING:
Expand All @@ -190,9 +191,11 @@ def _resolve_function_hierarchy(
elif stage == RunningStage.PREDICTING:
prefixes += ['predict']

prefixes += [None]
tchaton marked this conversation as resolved.
Show resolved Hide resolved

for prefix in prefixes:
if cls._is_overriden(function_name, process_obj, object_type, prefix=prefix):
return f'{prefix}_{function_name}'
return function_name if prefix is None else f'{prefix}_{function_name}'

return function_name

Expand Down
10 changes: 3 additions & 7 deletions flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,10 @@ def generate_dataset(

mock_dataset = typing.cast(AutoDataset, MockDataset())
with CurrentRunningStageFuncContext(running_stage, "load_data", self):
load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
self, DataPipeline._resolve_function_hierarchy(
"load_data",
self,
running_stage,
DataSource,
)
resolved_func_name = DataPipeline._resolve_function_hierarchy(
"load_data", self, running_stage, DataSource
)
load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(self, resolved_func_name)
parameters = signature(load_data).parameters
if len(parameters) > 1 and "dataset" in parameters: # TODO: This was DATASET_KEY before
data = load_data(data, mock_dataset)
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from flash.core.data.batch import default_uncollate
from flash.core.data.callback import FlashCallback
from flash.core.data.data_source import DataSource
from flash.core.data.data_source import DataSource, DefaultDataKeys
from flash.core.data.properties import Properties
from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext

Expand Down
5 changes: 5 additions & 0 deletions flash/core/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
return result
return x

def __repr__(self):
keys = self.keys[0] if len(self.keys) == 1 else self.keys
transform = [c for c in self.children()]
return f"{self.__class__.__name__}(keys={keys}, transform={transform})"


class KorniaParallelTransforms(nn.Sequential):
"""The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each
Expand Down
3 changes: 3 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,13 @@ def _compare_version(package: str, op, version) -> bool:
_PYTORCHVIDEO_AVAILABLE = _module_available("pytorchvideo")
_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")
_TRANSFORMERS_AVAILABLE = _module_available("transformers")
_PYSTICHE_AVAILABLE = _module_available("pystiche")
pmeier marked this conversation as resolved.
Show resolved Hide resolved

if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
_PYSTICHE_GREATER_EQUAL_0_7_2 = _compare_version("pystiche", operator.ge, "0.7.2")

_IMAGE_STLYE_TRANSFER = _PYSTICHE_AVAILABLE
_TEXT_AVAILABLE = _TRANSFORMERS_AVAILABLE
_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE
_VIDEO_AVAILABLE = _PYTORCHVIDEO_AVAILABLE
Expand Down
1 change: 1 addition & 0 deletions flash/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from flash.image.detection import ObjectDetectionData, ObjectDetector
from flash.image.embedding import ImageEmbedder
from flash.image.segmentation import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
from flash.image.style_transfer import StyleTransfer, StyleTransferData, StyleTransferPreprocess
3 changes: 3 additions & 0 deletions flash/image/style_transfer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from flash.image.style_transfer.backbone import STYLE_TRANSFER_BACKBONES
from flash.image.style_transfer.data import StyleTransferData, StyleTransferPreprocess
from flash.image.style_transfer.model import StyleTransfer
24 changes: 24 additions & 0 deletions flash/image/style_transfer/backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _PYSTICHE_AVAILABLE

STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones")

__all__ = ["STYLE_TRANSFER_BACKBONES"]

if _PYSTICHE_AVAILABLE:

from pystiche import enc

for mle_fn in dir(enc):

if "multi_layer_encoder" not in mle_fn:
continue

name = mle_fn.split("_")[0]
tchaton marked this conversation as resolved.
Show resolved Hide resolved

STYLE_TRANSFER_BACKBONES(
fn=lambda: (getattr(enc, mle_fn)(), None),
name=mle_fn.split("_")[0],
tchaton marked this conversation as resolved.
Show resolved Hide resolved
namespace="image/style_transfer",
package="pystiche",
)
110 changes: 110 additions & 0 deletions flash/image/style_transfer/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import functools
import pathlib
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

from torch import nn

from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.image.classification import ImageClassificationData, ImageClassificationPreprocess
from flash.image.style_transfer.utils import raise_not_supported

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as T

__all__ = ["StyleTransferPreprocess", "StyleTransferData"]


def _apply_to_input(default_transforms_fn, keys: Union[Sequence[DefaultDataKeys],
DefaultDataKeys]) -> Callable[..., Dict[str, ApplyToKeys]]:

@functools.wraps(default_transforms_fn)
def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]:
default_transforms = default_transforms_fn(*args, **kwargs)
if not default_transforms:
return default_transforms

return {hook: ApplyToKeys(keys, transform) for hook, transform in default_transforms.items()}

return wrapper


class StyleTransferPreprocess(ImageClassificationPreprocess):

def __init__(
self,
train_transform: Optional[Union[Dict[str, Callable]]] = None,
val_transform: Optional[Union[Dict[str, Callable]]] = None,
test_transform: Optional[Union[Dict[str, Callable]]] = None,
predict_transform: Optional[Union[Dict[str, Callable]]] = None,
image_size: Union[int, Tuple[int, int]] = 256,
):
if val_transform:
raise_not_supported("validation")
if test_transform:
raise_not_supported("test")

if isinstance(image_size, int):
image_size = (image_size, image_size)
super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
image_size=image_size,
)

@functools.partial(_apply_to_input, keys=DefaultDataKeys.INPUT)
def default_transforms(self) -> Optional[Dict[str, Callable]]:
if self.training:
return dict(
to_tensor_transform=T.ToTensor(),
per_sample_transform_on_device=nn.Sequential(
T.Resize(self.image_size),
T.CenterCrop(self.image_size),
),
)
elif self.predicting:
return dict(
pre_tensor_transform=T.Resize(self.image_size),
to_tensor_transform=T.ToTensor(),
)
else:
# Style transfer doesn't support a validation or test phase, so we return nothing here
return None


class StyleTransferData(ImageClassificationData):
preprocess_cls = StyleTransferPreprocess

@classmethod
def from_folders(
cls,
train_folder: Optional[Union[str, pathlib.Path]] = None,
predict_folder: Optional[Union[str, pathlib.Path]] = None,
train_transform: Optional[Union[str, Dict]] = None,
predict_transform: Optional[Union[str, Dict]] = None,
preprocess: Optional[Preprocess] = None,
**kwargs: Any,
) -> "StyleTransferData":

if any(param in kwargs for param in ("val_folder", "val_transform")):
raise_not_supported("validation")

if any(param in kwargs for param in ("test_folder", "test_transform")):
raise_not_supported("test")

preprocess = preprocess or cls.preprocess_cls(
train_transform=train_transform,
predict_transform=predict_transform,
)

return cls.from_data_source(
DefaultDataSources.FOLDERS,
train_data=train_folder,
predict_data=predict_folder,
preprocess=preprocess,
**kwargs,
)
Loading