Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[WIP] add style transfer task with pystiche (#262)
Browse files Browse the repository at this point in the history
* add style transfer task with pystiche

* address review comments

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix type hint

* allow passing style_image by path

* add batch_size

* add data_module based on image classification

* add internal pre / post-processing

* bail out if val / test step is performed

* update example

* move example from predict to finetuning

* remove metrics from task

* flake8

* remove unused imports

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove grayscale handling

* address review comments and small fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* streamline apply_to_input

* fix hyper parameters saving

* implement custom step

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add explanation to not supported phases

* temporarily use unreleased pystiche version

* add missing transforms in preprocess

* introduce multi layer encoders as backbones

* refactor task

* add explanation for modified gram operator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* streamline default transforms

* add disabled test for finetuning example

* add documentation skeleton

* update changelog

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* change skipif

* update

* update

* update

* fix image size for preprocess

* fix style transfer requirements

* update

* update doc

* fix style transfer requirements

* update

* add reference to pystiche

* remove unnecessary import

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton <[email protected]>
  • Loading branch information
3 people authored May 17, 2021
1 parent c28bafa commit 7c89fc1
Show file tree
Hide file tree
Showing 35 changed files with 559 additions and 12 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ jobs:
python-version: 3.8
requires: 'latest'
topic: 'text'
- os: ubuntu-20.04
python-version: 3.8
requires: 'latest'
topic: 'image_style_transfer'

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState ([#229](https://github.com/PyTorchLightning/lightning-flash/pull/229))
- Added Semantic Segmentation task ([#239](https://github.com/PyTorchLightning/lightning-flash/pull/239) [#287](https://github.com/PyTorchLightning/lightning-flash/pull/287) [#290](https://github.com/PyTorchLightning/lightning-flash/pull/290))
- Added Object detection prediction example ([#283](https://github.com/PyTorchLightning/lightning-flash/pull/283))
- Added Style Transfer task and accompanying finetuning and prediction examples ([#262](https://github.com/PyTorchLightning/lightning-flash/pull/262))

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Lightning Flash
reference/object_detection
reference/video_classification
reference/semantic_segmentation
reference/style_transfer

.. toctree::
:maxdepth: 1
Expand Down
82 changes: 82 additions & 0 deletions docs/source/reference/style_transfer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
##############
Style Transfer
##############

********
The task
********

The Neural Style Transfer Task is an optimization method which extract the style from an image and apply it another image while preserving its content.
The goal is that the output image looks like the content image, but “painted” in the style of the style reference image.

.. image:: https://raw.githubusercontent.com/pystiche/pystiche/master/docs/source/graphics/banner/banner.jpg
:alt: style_transfer_example

Lightning Flash :class:`~flash.image.style_transfer.StyleTransfer` and
:class:`~flash.image.style_transfer.StyleTransferData` internally rely on `pystiche <https://pystiche.org>`_ as
backend.

------

***
Fit
***

First, you would have to import the :class:`~flash.image.style_transfer.StyleTransfer`
and :class:`~flash.image.style_transfer.StyleTransferData` from Flash.

.. testcode:: style_transfer

import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData
import pystiche


Then, download some content images and create a :class:`~flash.image.style_transfer.StyleTransferData` DataModule.

.. testcode:: style_transfer

download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4)


Select a style image and pass it to the `StyleTransfer` task.

.. testcode:: style_transfer

style_image = pystiche.demo.images()["paint"].read(size=256)

model = StyleTransfer(style_image)

Finally, create a Flash :class:`flash.core.trainer.Trainer` and pass it the model and datamodule.

.. testcode:: style_transfer

trainer = flash.Trainer(max_epochs=2)
trainer.fit(model, data_module)

.. testoutput::
:hide:

...


------

*************
API reference
*************

StyleTransfer
-------------

.. autoclass:: flash.image.StyleTransfer
:members:
:exclude-members: forward

StyleTransferData
-----------------

.. autoclass:: flash.image.StyleTransferData
7 changes: 5 additions & 2 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def _resolve_function_hierarchy(
if object_type is None:
object_type = Preprocess

prefixes = ['']
prefixes = []

if stage in (RunningStage.TRAINING, RunningStage.TUNING):
prefixes += ['train', 'fit']
elif stage == RunningStage.VALIDATING:
Expand All @@ -190,9 +191,11 @@ def _resolve_function_hierarchy(
elif stage == RunningStage.PREDICTING:
prefixes += ['predict']

prefixes += [None]

for prefix in prefixes:
if cls._is_overriden(function_name, process_obj, object_type, prefix=prefix):
return f'{prefix}_{function_name}'
return function_name if prefix is None else f'{prefix}_{function_name}'

return function_name

Expand Down
10 changes: 3 additions & 7 deletions flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,10 @@ def generate_dataset(

mock_dataset = typing.cast(AutoDataset, MockDataset())
with CurrentRunningStageFuncContext(running_stage, "load_data", self):
load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
self, DataPipeline._resolve_function_hierarchy(
"load_data",
self,
running_stage,
DataSource,
)
resolved_func_name = DataPipeline._resolve_function_hierarchy(
"load_data", self, running_stage, DataSource
)
load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(self, resolved_func_name)
parameters = signature(load_data).parameters
if len(parameters) > 1 and "dataset" in parameters: # TODO: This was DATASET_KEY before
data = load_data(data, mock_dataset)
Expand Down
5 changes: 5 additions & 0 deletions flash/core/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
return result
return x

def __repr__(self):
keys = self.keys[0] if len(self.keys) == 1 else self.keys
transform = [c for c in self.children()]
return f"{self.__class__.__name__}(keys={keys}, transform={transform})"


class KorniaParallelTransforms(nn.Sequential):
"""The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each
Expand Down
3 changes: 3 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,13 @@ def _compare_version(package: str, op, version) -> bool:
_PYTORCHVIDEO_AVAILABLE = _module_available("pytorchvideo")
_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")
_TRANSFORMERS_AVAILABLE = _module_available("transformers")
_PYSTICHE_AVAILABLE = _module_available("pystiche")

if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
_PYSTICHE_GREATER_EQUAL_0_7_2 = _compare_version("pystiche", operator.ge, "0.7.2")

_IMAGE_STLYE_TRANSFER = _PYSTICHE_AVAILABLE
_TEXT_AVAILABLE = _TRANSFORMERS_AVAILABLE
_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE
_VIDEO_AVAILABLE = _PYTORCHVIDEO_AVAILABLE
Expand Down
1 change: 1 addition & 0 deletions flash/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from flash.image.detection import ObjectDetectionData, ObjectDetector
from flash.image.embedding import ImageEmbedder
from flash.image.segmentation import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
from flash.image.style_transfer import StyleTransfer, StyleTransferData, StyleTransferPreprocess
3 changes: 3 additions & 0 deletions flash/image/style_transfer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from flash.image.style_transfer.backbone import STYLE_TRANSFER_BACKBONES
from flash.image.style_transfer.data import StyleTransferData, StyleTransferPreprocess
from flash.image.style_transfer.model import StyleTransfer
28 changes: 28 additions & 0 deletions flash/image/style_transfer/backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import re

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _PYSTICHE_AVAILABLE

STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones")

__all__ = ["STYLE_TRANSFER_BACKBONES"]

if _PYSTICHE_AVAILABLE:

from pystiche import enc

MLE_FN_PATTERN = re.compile(r"^(?P<name>\w+?)_multi_layer_encoder$")

STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones")

for mle_fn in dir(enc):
match = MLE_FN_PATTERN.match(mle_fn)
if not match:
continue

STYLE_TRANSFER_BACKBONES(
fn=lambda: (getattr(enc, mle_fn)(), None),
name=match.group("name"),
namespace="image/style_transfer",
package="pystiche",
)
127 changes: 127 additions & 0 deletions flash/image/style_transfer/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import functools
import pathlib
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

from torch import nn

from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.image.classification import ImageClassificationData
from flash.image.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource
from flash.image.style_transfer.utils import raise_not_supported

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as T

__all__ = ["StyleTransferPreprocess", "StyleTransferData"]


def _apply_to_input(default_transforms_fn, keys: Union[Sequence[DefaultDataKeys],
DefaultDataKeys]) -> Callable[..., Dict[str, ApplyToKeys]]:

@functools.wraps(default_transforms_fn)
def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]:
default_transforms = default_transforms_fn(*args, **kwargs)
if not default_transforms:
return default_transforms

return {hook: ApplyToKeys(keys, transform) for hook, transform in default_transforms.items()}

return wrapper


class StyleTransferPreprocess(Preprocess):

def __init__(
self,
train_transform: Optional[Union[Dict[str, Callable]]] = None,
val_transform: Optional[Union[Dict[str, Callable]]] = None,
test_transform: Optional[Union[Dict[str, Callable]]] = None,
predict_transform: Optional[Union[Dict[str, Callable]]] = None,
image_size: int = 256,
):
if val_transform:
raise_not_supported("validation")
if test_transform:
raise_not_supported("test")

if isinstance(image_size, int):
image_size = (image_size, image_size)

self.image_size = image_size

super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
DefaultDataSources.NUMPY: ImageNumpyDataSource(),
DefaultDataSources.TENSORS: ImageTensorDataSource(),
DefaultDataSources.TENSORS: ImageTensorDataSource(),
},
default_data_source=DefaultDataSources.FILES,
)

def get_state_dict(self) -> Dict[str, Any]:
return {**self.transforms, "image_size": self.image_size}

@classmethod
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
return cls(**state_dict)

@functools.partial(_apply_to_input, keys=DefaultDataKeys.INPUT)
def default_transforms(self) -> Optional[Dict[str, Callable]]:
if self.training:
return dict(
to_tensor_transform=T.ToTensor(),
per_sample_transform_on_device=nn.Sequential(
T.Resize(self.image_size),
T.CenterCrop(self.image_size),
),
)
elif self.predicting:
return dict(
pre_tensor_transform=T.Resize(self.image_size),
to_tensor_transform=T.ToTensor(),
)
# Style transfer doesn't support a validation or test phase, so we return nothing here
return None


class StyleTransferData(ImageClassificationData):
preprocess_cls = StyleTransferPreprocess

@classmethod
def from_folders(
cls,
train_folder: Optional[Union[str, pathlib.Path]] = None,
predict_folder: Optional[Union[str, pathlib.Path]] = None,
train_transform: Optional[Union[str, Dict]] = None,
predict_transform: Optional[Union[str, Dict]] = None,
preprocess: Optional[Preprocess] = None,
**kwargs: Any,
) -> "StyleTransferData":

if any(param in kwargs for param in ("val_folder", "val_transform")):
raise_not_supported("validation")

if any(param in kwargs for param in ("test_folder", "test_transform")):
raise_not_supported("test")

preprocess = preprocess or cls.preprocess_cls(
train_transform=train_transform,
predict_transform=predict_transform,
)

return cls.from_data_source(
DefaultDataSources.FOLDERS,
train_data=train_folder,
predict_data=predict_folder,
preprocess=preprocess,
**kwargs,
)
Loading

0 comments on commit 7c89fc1

Please sign in to comment.