[WIP] add style transfer task with pystiche (#262)
* add style transfer task with pystiche
* address review comments
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix type hint
* allow passing style_image by path
* add batch_size
* add data_module based on image classification
* add internal pre / post-processing
* bail out if val / test step is performed
* update example
* move example from predict to finetuning
* remove metrics from task
* flake8
* remove unused imports
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove grayscale handling
* address review comments and small fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* streamline apply_to_input
* fix hyper parameters saving
* implement custom step
* cleanup
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add explanation to not supported phases
* temporarily use unreleased pystiche version
* add missing transforms in preprocess
* introduce multi layer encoders as backbones
* refactor task
* add explanation for modified gram operator
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* streamline default transforms
* add disabled test for finetuning example
* add documentation skeleton
* update changelog
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* change skipif
* update
* update
* update
* fix image size for preprocess
* fix style transfer requirements
* update
* update doc
* fix style transfer requirements
* update
* add reference to pystiche
* remove unnecessary import

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton <[email protected]>
1 parent c28bafa, commit 7c89fc1
Showing 35 changed files with 559 additions and 12 deletions.
@@ -0,0 +1,82 @@

##############
Style Transfer
##############

********
The task
********

The Neural Style Transfer Task is an optimization method which extracts the style from an image and applies it to another image while preserving its content.
The goal is that the output image looks like the content image, but "painted" in the style of the style reference image.

.. image:: https://raw.githubusercontent.com/pystiche/pystiche/master/docs/source/graphics/banner/banner.jpg
    :alt: style_transfer_example

The Lightning Flash :class:`~flash.image.style_transfer.StyleTransfer` and
:class:`~flash.image.style_transfer.StyleTransferData` classes internally rely on `pystiche <https://pystiche.org>`_ as
their backend.

------

***
Fit
***

First, import :class:`~flash.image.style_transfer.StyleTransfer`
and :class:`~flash.image.style_transfer.StyleTransferData` from Flash.

.. testcode:: style_transfer

    import flash
    from flash.core.data.utils import download_data
    from flash.image.style_transfer import StyleTransfer, StyleTransferData
    import pystiche

Then, download some content images and create a :class:`~flash.image.style_transfer.StyleTransferData` DataModule.

.. testcode:: style_transfer

    download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

    data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4)

Select a style image and pass it to the :class:`~flash.image.style_transfer.StyleTransfer` task.

.. testcode:: style_transfer

    style_image = pystiche.demo.images()["paint"].read(size=256)

    model = StyleTransfer(style_image)

Finally, create a Flash :class:`flash.core.trainer.Trainer` and pass it the model and datamodule.

.. testcode:: style_transfer

    trainer = flash.Trainer(max_epochs=2)
    trainer.fit(model, data_module)

.. testoutput::
    :hide:

    ...

------

*************
API reference
*************

StyleTransfer
-------------

.. autoclass:: flash.image.StyleTransfer
    :members:
    :exclude-members: forward

StyleTransferData
-----------------

.. autoclass:: flash.image.StyleTransferData
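
The commit also registers pystiche's multi layer encoders as backbones (see backbone.py further down). As a minimal sketch, assuming FlashRegistry exposes an available_keys() helper as in other Flash registries, the registered backbone names could be listed like this:

from flash.image.style_transfer import STYLE_TRANSFER_BACKBONES

# Print the backbone names registered from pystiche, e.g. "vgg16" or "vgg19".
# available_keys() is assumed here to be provided by FlashRegistry.
print(STYLE_TRANSFER_BACKBONES.available_keys())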
@@ -0,0 +1,3 @@

from flash.image.style_transfer.backbone import STYLE_TRANSFER_BACKBONES
from flash.image.style_transfer.data import StyleTransferData, StyleTransferPreprocess
from flash.image.style_transfer.model import StyleTransfer
@@ -0,0 +1,28 @@

import re

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _PYSTICHE_AVAILABLE

STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones")

__all__ = ["STYLE_TRANSFER_BACKBONES"]

if _PYSTICHE_AVAILABLE:

    from pystiche import enc

    MLE_FN_PATTERN = re.compile(r"^(?P<name>\w+?)_multi_layer_encoder$")

    # Register every multi layer encoder exposed by pystiche
    # (e.g. vgg16_multi_layer_encoder) under its short name (e.g. "vgg16").
    for mle_fn in dir(enc):
        match = MLE_FN_PATTERN.match(mle_fn)
        if not match:
            continue

        STYLE_TRANSFER_BACKBONES(
            # Bind mle_fn as a default argument so each lambda builds its own encoder
            # instead of all of them resolving to the last loop value.
            fn=lambda mle_fn=mle_fn: (getattr(enc, mle_fn)(), None),
            name=match.group("name"),
            namespace="image/style_transfer",
            package="pystiche",
        )
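
For reference, a small sketch of how one of the entries registered above might be resolved; the "vgg16" name assumes pystiche exposes vgg16_multi_layer_encoder, and building the encoder downloads pretrained weights on first use:

from flash.image.style_transfer.backbone import STYLE_TRANSFER_BACKBONES

# Look up the registered constructor by its short name and build the encoder.
# Each registered fn returns a (multi_layer_encoder, metadata) tuple, with metadata set to None.
build_mle = STYLE_TRANSFER_BACKBONES.get("vgg16")
multi_layer_encoder, _ = build_mle()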
@@ -0,0 +1,127 @@

import functools
import pathlib
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

from torch import nn

from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.image.classification import ImageClassificationData
from flash.image.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource
from flash.image.style_transfer.utils import raise_not_supported

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms as T

__all__ = ["StyleTransferPreprocess", "StyleTransferData"]


def _apply_to_input(
    default_transforms_fn: Callable,
    keys: Union[Sequence[DefaultDataKeys], DefaultDataKeys],
) -> Callable[..., Dict[str, ApplyToKeys]]:

    @functools.wraps(default_transforms_fn)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]:
        default_transforms = default_transforms_fn(*args, **kwargs)
        if not default_transforms:
            return default_transforms

        # Wrap every default transform so it is only applied to the given keys of the sample dict.
        return {hook: ApplyToKeys(keys, transform) for hook, transform in default_transforms.items()}

    return wrapper


class StyleTransferPreprocess(Preprocess):

    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        image_size: Union[int, Tuple[int, int]] = 256,
    ):
        if val_transform:
            raise_not_supported("validation")
        if test_transform:
            raise_not_supported("test")

        if isinstance(image_size, int):
            image_size = (image_size, image_size)

        self.image_size = image_size

        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources={
                DefaultDataSources.FILES: ImagePathsDataSource(),
                DefaultDataSources.FOLDERS: ImagePathsDataSource(),
                DefaultDataSources.NUMPY: ImageNumpyDataSource(),
                DefaultDataSources.TENSORS: ImageTensorDataSource(),
            },
            default_data_source=DefaultDataSources.FILES,
        )

    def get_state_dict(self) -> Dict[str, Any]:
        return {**self.transforms, "image_size": self.image_size}

    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
        return cls(**state_dict)

    @functools.partial(_apply_to_input, keys=DefaultDataKeys.INPUT)
    def default_transforms(self) -> Optional[Dict[str, Callable]]:
        if self.training:
            return dict(
                to_tensor_transform=T.ToTensor(),
                per_sample_transform_on_device=nn.Sequential(
                    T.Resize(self.image_size),
                    T.CenterCrop(self.image_size),
                ),
            )
        elif self.predicting:
            return dict(
                pre_tensor_transform=T.Resize(self.image_size),
                to_tensor_transform=T.ToTensor(),
            )
        # Style transfer doesn't support a validation or test phase, so we return nothing here
        return None


class StyleTransferData(ImageClassificationData):
    preprocess_cls = StyleTransferPreprocess

    @classmethod
    def from_folders(
        cls,
        train_folder: Optional[Union[str, pathlib.Path]] = None,
        predict_folder: Optional[Union[str, pathlib.Path]] = None,
        train_transform: Optional[Union[str, Dict]] = None,
        predict_transform: Optional[Union[str, Dict]] = None,
        preprocess: Optional[Preprocess] = None,
        **kwargs: Any,
    ) -> "StyleTransferData":

        if any(param in kwargs for param in ("val_folder", "val_transform")):
            raise_not_supported("validation")

        if any(param in kwargs for param in ("test_folder", "test_transform")):
            raise_not_supported("test")

        preprocess = preprocess or cls.preprocess_cls(
            train_transform=train_transform,
            predict_transform=predict_transform,
        )

        return cls.from_data_source(
            DefaultDataSources.FOLDERS,
            train_data=train_folder,
            predict_data=predict_folder,
            preprocess=preprocess,
            **kwargs,
        )
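
A short usage sketch for this data module, reusing the coco128 folder from the documentation above; the image_size and batch_size values are arbitrary choices for illustration:

from flash.image.style_transfer import StyleTransferData, StyleTransferPreprocess

# Only train and predict folders are accepted; validation / test inputs raise an error.
datamodule = StyleTransferData.from_folders(
    train_folder="data/coco128/images",
    predict_folder="data/coco128/images",
    preprocess=StyleTransferPreprocess(image_size=256),
    batch_size=4,
)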