From 413a53021feb99be5b9d095b4b1f3cb1dfd3160a Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 5 May 2021 21:55:32 +0200
Subject: [PATCH 01/60] add style transfer task with pystiche

---
 flash/utils/imports.py                   |  1 +
 flash/vision/__init__.py                 |  1 +
 flash/vision/style_transfer/__init__.py  |  2 +
 flash/vision/style_transfer/data.py      |  1 +
 flash/vision/style_transfer/model.py     | 90 ++++++++++++++++++++++++
 flash_examples/predict/style_transfer.py | 47 +++++++++++++
 requirements.txt                         |  1 +
 7 files changed, 143 insertions(+)
 create mode 100644 flash/vision/style_transfer/__init__.py
 create mode 100644 flash/vision/style_transfer/data.py
 create mode 100644 flash/vision/style_transfer/model.py
 create mode 100644 flash_examples/predict/style_transfer.py

diff --git a/flash/utils/imports.py b/flash/utils/imports.py
index d50ec7bac8..c7f9c08c4f 100644
--- a/flash/utils/imports.py
+++ b/flash/utils/imports.py
@@ -8,3 +8,4 @@
 _PYTORCHVIDEO_AVAILABLE = _module_available("pytorchvideo")
 _MATPLOTLIB_AVAILABLE = _module_available("matplotlib")
 _TRANSFORMERS_AVAILABLE = _module_available("transformers")
+_PYSTICHE_AVAILABLE = _module_available("pystiche")

diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py
index 39dce803d8..03e2b012b6 100644
--- a/flash/vision/__init__.py
+++ b/flash/vision/__init__.py
@@ -2,3 +2,4 @@
 from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier
 from flash.vision.detection import ObjectDetectionData, ObjectDetector
 from flash.vision.embedding import ImageEmbedder
+from .style_transfer import *

diff --git a/flash/vision/style_transfer/__init__.py b/flash/vision/style_transfer/__init__.py
new file mode 100644
index 0000000000..5d7af11b13
--- /dev/null
+++ b/flash/vision/style_transfer/__init__.py
@@ -0,0 +1,2 @@
+from .data import *
+from .model import *

diff --git a/flash/vision/style_transfer/data.py b/flash/vision/style_transfer/data.py
new file mode 100644
index 0000000000..464090415c
--- /dev/null
+++ b/flash/vision/style_transfer/data.py
@@ -0,0 +1 @@
+# TODO

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
new file mode 100644
index 0000000000..aaaa3b85b8
--- /dev/null
+++ b/flash/vision/style_transfer/model.py
@@ -0,0 +1,90 @@
+from torch import nn
+from torch.nn.functional import interpolate
+
+__all__ = ["Transformer"]
+
+
+class Interpolate(nn.Module):
+    def __init__(self, scale_factor=1.0, mode="nearest"):
+        super().__init__()
+        self.scale_factor = scale_factor
+        self.mode = mode
+
+    def forward(self, input):
+        return interpolate(input, scale_factor=self.scale_factor, mode=self.mode)
+
+    def extra_repr(self):
+        extras = []
+        if self.scale_factor:
+            extras.append(f"scale_factor={self.scale_factor}")
+        if self.mode != "nearest":
+            extras.append(f"mode={self.mode}")
+        return ", ".join(extras)
+
+
+class Conv(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        upsample=False,
+        norm=True,
+        activation=True,
+    ):
+        super().__init__()
+        self.upsample = Interpolate(scale_factor=stride) if upsample else None
+        self.pad = nn.ReflectionPad2d(kernel_size // 2)
+        self.conv = nn.Conv2d(
+            in_channels, out_channels, kernel_size, stride=1 if upsample else stride
+        )
+        self.norm = nn.InstanceNorm2d(out_channels, affine=True) if norm else None
+        self.activation = nn.ReLU() if activation else None
+
+    def forward(self, input):
+        if self.upsample:
+            input = self.upsample(input)
+
+        output = self.conv(self.pad(input))
+
+        if self.norm:
+            output = self.norm(output)
+        if self.activation:
+            output = self.activation(output)
+
+        return output
+
+
+class Residual(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.conv1 = Conv(channels, channels, kernel_size=3)
+        self.conv2 = Conv(channels, channels, kernel_size=3, activation=False)
+
+    def forward(self, input):
+        output = self.conv2(self.conv1(input))
+        return output + input
+
+
+class Transformer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            Conv(3, 32, kernel_size=9),
+            Conv(32, 64, kernel_size=3, stride=2),
+            Conv(64, 128, kernel_size=3, stride=2),
+            Residual(128),
+            Residual(128),
+            Residual(128),
+            Residual(128),
+            Residual(128),
+        )
+        self.decoder = nn.Sequential(
+            Conv(128, 64, kernel_size=3, stride=2, upsample=True),
+            Conv(64, 32, kernel_size=3, stride=2, upsample=True),
+            Conv(32, 3, kernel_size=9, norm=False, activation=False),
+        )
+
+    def forward(self, input):
+        return self.decoder(self.encoder(input))
\ No newline at end of file

diff --git a/flash_examples/predict/style_transfer.py b/flash_examples/predict/style_transfer.py
new file mode 100644
index 0000000000..e186254135
--- /dev/null
+++ b/flash_examples/predict/style_transfer.py
@@ -0,0 +1,47 @@
+import sys
+
+import torch
+
+from flash.utils.imports import _PYSTICHE_AVAILABLE
+
+if _PYSTICHE_AVAILABLE:
+    from pystiche import enc, loss, ops
+else:
+    print("Please, run `pip install pystiche`")
+    sys.exit(0)
+
+multi_layer_encoder = enc.vgg16_multi_layer_encoder()
+
+content_layer = "relu2_2"
+content_encoder = multi_layer_encoder.extract_encoder(content_layer)
+content_weight = 1e5
+content_loss = ops.FeatureReconstructionOperator(
+    content_encoder, score_weight=content_weight
+)
+
+
+class GramOperator(ops.GramOperator):
+    def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
+        repr = super().enc_to_repr(enc)
+        num_channels = repr.size()[1]
+        return repr / num_channels
+
+
+style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
+style_weight = 1e10
+style_loss = ops.MultiLayerEncodingOperator(
+    multi_layer_encoder,
+    style_layers,
+    lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
+    layer_weights="sum",
+    score_weight=style_weight,
+)
+
+# TODO: this needs to be moved to the device to be trained on
+# TODO: we need to register a style image here
+perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
+
+
+def loss_fn(image):
+    perceptual_loss.set_content_image(image)
+    return float(perceptual_loss(image))
\ No newline at end of file

diff --git a/requirements.txt b/requirements.txt
index 8aaa1ec97d..a611142730 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,4 @@ pycocotools>=2.0.2 ; python_version >= "3.7"
 kornia==0.5.0
 pytorchvideo
 matplotlib # used by the visualisation callback
+pystiche>=0.7.1
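Patch 01's Transformer is a fully convolutional encoder-decoder: two stride-2 convolutions halve the resolution twice, five residual blocks transform the features, and two nearest-neighbor upsampling stages restore the original size, so stylized outputs keep the spatial shape of their inputs. A minimal smoke test of that round-trip, assuming the patch above is applied as-is; the random tensor is a stand-in for a real content batch:

    import torch

    from flash.vision.style_transfer.model import Transformer

    transformer = Transformer()

    # 256 -> 128 -> 64 in the encoder, 64 -> 128 -> 256 in the decoder.
    content = torch.rand(1, 3, 256, 256)
    with torch.no_grad():
        stylized = transformer(content)
    assert stylized.shape == content.shape
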
From be5a893a6d9a033dbb3d0d4e6b0fdc8f007d83e0 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 10 May 2021 09:12:24 +0200
Subject: [PATCH 02/60] address review comments

---
 flash/vision/__init__.py                 |   2 +-
 flash/vision/style_transfer/__init__.py  |   3 +-
 flash/vision/style_transfer/data.py      |   1 -
 flash/vision/style_transfer/model.py     | 134 +++++++++++++++++++----
 flash_examples/predict/style_transfer.py |  44 ++------
 5 files changed, 123 insertions(+), 61 deletions(-)
 delete mode 100644 flash/vision/style_transfer/data.py

diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py
index 03e2b012b6..81c24bf13e 100644
--- a/flash/vision/__init__.py
+++ b/flash/vision/__init__.py
@@ -2,4 +2,4 @@
 from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier
 from flash.vision.detection import ObjectDetectionData, ObjectDetector
 from flash.vision.embedding import ImageEmbedder
-from .style_transfer import *
+from flash.vision.style_transfer import StyleTransfer

diff --git a/flash/vision/style_transfer/__init__.py b/flash/vision/style_transfer/__init__.py
index 5d7af11b13..aa70fa9459 100644
--- a/flash/vision/style_transfer/__init__.py
+++ b/flash/vision/style_transfer/__init__.py
@@ -1,2 +1 @@
-from .data import *
-from .model import *
+from flash.vision.style_transfer.model import StyleTransfer

diff --git a/flash/vision/style_transfer/data.py b/flash/vision/style_transfer/data.py
deleted file mode 100644
index 464090415c..0000000000
--- a/flash/vision/style_transfer/data.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index aaaa3b85b8..4e918f556a 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -1,22 +1,29 @@
+from typing import Any, Dict, Mapping, Optional, Sequence, Type, Union
+
+import torch
+import torchmetrics
+from pystiche import enc, loss, ops
 from torch import nn
 from torch.nn.functional import interpolate
+from torch.optim.lr_scheduler import _LRScheduler
+
+from flash.core import Task
+from flash.data.process import Serializer
 
-__all__ = ["Transformer"]
+__all__ = ["StyleTransfer"]
 
 
 class Interpolate(nn.Module):
-    def __init__(self, scale_factor=1.0, mode="nearest"):
+    def __init__(self, scale_factor: float = 1.0, mode: str = "nearest") -> None:
         super().__init__()
         self.scale_factor = scale_factor
         self.mode = mode
 
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return interpolate(input, scale_factor=self.scale_factor, mode=self.mode)
 
-    def extra_repr(self):
-        extras = []
-        if self.scale_factor:
-            extras.append(f"scale_factor={self.scale_factor}")
+    def extra_repr(self) -> str:
+        extras = [f"scale_factor={self.scale_factor}"]
         if self.mode != "nearest":
             extras.append(f"mode={self.mode}")
         return ", ".join(extras)
@@ -25,24 +32,23 @@ def extra_repr(self):
 class Conv(nn.Module):
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        upsample=False,
-        norm=True,
-        activation=True,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        *,
+        stride: int = 1,
+        upsample: bool = False,
+        norm: bool = True,
+        activation: bool = True,
     ):
         super().__init__()
         self.upsample = Interpolate(scale_factor=stride) if upsample else None
         self.pad = nn.ReflectionPad2d(kernel_size // 2)
-        self.conv = nn.Conv2d(
-            in_channels, out_channels, kernel_size, stride=1 if upsample else stride
-        )
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1 if upsample else stride)
         self.norm = nn.InstanceNorm2d(out_channels, affine=True) if norm else None
         self.activation = nn.ReLU() if activation else None
 
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.upsample:
             input = self.upsample(input)
 
@@ -57,18 +63,18 @@ def forward(self, input):
 class Residual(nn.Module):
-    def __init__(self, channels):
+    def __init__(self, channels: int) -> None:
         super().__init__()
         self.conv1 = Conv(channels, channels, kernel_size=3)
         self.conv2 = Conv(channels, channels, kernel_size=3, activation=False)
 
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         output = self.conv2(self.conv1(input))
         return output + input
 
 
 class Transformer(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.encoder = nn.Sequential(
             Conv(3, 32, kernel_size=9),
@@ -86,5 +92,87 @@ def __init__(self):
             Conv(32, 3, kernel_size=9, norm=False, activation=False),
         )
 
-    def forward(self, input):
-        return self.decoder(self.encoder(input))
\ No newline at end of file
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return self.decoder(self.encoder(input))
+
+
+class StyleTransfer(Task):
+    def __init__(
+        self,
+        style_image: torch.Tensor,
+        model: Optional[nn.Module] = None,
+        multi_layer_encoder: Optional = None,
+        content_loss: Optional[Union[ops.ComparisonOperator, ops.OperatorContainer]] = None,
+        style_loss: Optional[Union[ops.ComparisonOperator, ops.OperatorContainer]] = None,
+        optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
+        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
+        learning_rate: float = 1e-3,
+        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+    ):
+        if multi_layer_encoder is None:
+            multi_layer_encoder = self.default_multi_layer_encoder()
+
+        if content_loss is None:
+            content_loss = self.default_content_loss(multi_layer_encoder)
+
+        if style_loss is None:
+            style_loss = self.default_style_loss(multi_layer_encoder)
+
+        self.perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
+        self.perceptual_loss.set_style_image(style_image)
+
+        self.save_hyperparameters()
+
+        super().__init__(
+            model=model,
+            loss_fn=self.perceptual_loss,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
+            metrics=metrics,
+            learning_rate=learning_rate,
+            serializer=serializer,
+        )
+
+    def default_multi_layer_encoder(self) -> enc.MultiLayerEncoder:
+        return enc.vgg16_multi_layer_encoder()
+
+    def default_content_loss(
+        self, multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None
+    ) -> ops.FeatureReconstructionOperator:
+        if multi_layer_encoder is None:
+            multi_layer_encoder = self.default_multi_layer_encoder()
+        content_layer = "relu2_2"
+        content_encoder = multi_layer_encoder.extract_encoder(content_layer)
+        content_weight = 1e5
+        return ops.FeatureReconstructionOperator(content_encoder, score_weight=content_weight)
+
+    def default_style_loss(
+        self, multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None
+    ) -> ops.MultiLayerEncodingOperator:
+        class GramOperator(ops.GramOperator):
+            def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
+                repr = super().enc_to_repr(enc)
+                num_channels = repr.size()[1]
+                return repr / num_channels
+
+        if multi_layer_encoder is None:
+            multi_layer_encoder = self.default_multi_layer_encoder()
+
+        style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
+        style_weight = 1e10
+        return ops.MultiLayerEncodingOperator(
+            multi_layer_encoder,
+            style_layers,
+            lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
+            layer_weights="sum",
+            score_weight=style_weight,
+        )
+
+    def forward(self, content_image: torch.Tensor) -> torch.Tensor:
+        self.perceptual_loss.set_content_image(content_image)
+        return self.model(content_image)

diff --git a/flash_examples/predict/style_transfer.py b/flash_examples/predict/style_transfer.py
index e186254135..a83f7d6749 100644
--- a/flash_examples/predict/style_transfer.py
+++ b/flash_examples/predict/style_transfer.py
@@ -1,47 +1,23 @@
 import sys
 
-import torch
-
+import flash
+from flash.data.utils import download_data
 from flash.utils.imports import _PYSTICHE_AVAILABLE
+from flash.vision.style_transfer import StyleTransfer
 
 if _PYSTICHE_AVAILABLE:
-    from pystiche import enc, loss, ops
+    import pystiche.demo
 else:
     print("Please, run `pip install pystiche`")
     sys.exit(0)
 
-multi_layer_encoder = enc.vgg16_multi_layer_encoder()
-
-content_layer = "relu2_2"
-content_encoder = multi_layer_encoder.extract_encoder(content_layer)
-content_weight = 1e5
-content_loss = ops.FeatureReconstructionOperator(
-    content_encoder, score_weight=content_weight
-)
-
-
-class GramOperator(ops.GramOperator):
-    def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
-        repr = super().enc_to_repr(enc)
-        num_channels = repr.size()[1]
-        return repr / num_channels
-
+download_data("http://images.cocodataset.org/zips/train2014.zip", "data")
 
-style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
-style_weight = 1e10
-style_loss = ops.MultiLayerEncodingOperator(
-    multi_layer_encoder,
-    style_layers,
-    lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
-    layer_weights="sum",
-    score_weight=style_weight,
-)
+data_module = ImageUnsupervisedData.from_folder("data")
 
-# TODO: this needs to be moved to the device to be trained on
-# TODO: we need to register a style image here
-perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
+style_image = pystiche.demo.images()["paint"].read(size=256, edge="long")
 
+model = StyleTransfer(style_image)
 
-def loss_fn(image):
-    perceptual_loss.set_content_image(image)
-    return float(perceptual_loss(image))
\ No newline at end of file
+trainer = flash.Trainer(max_epochs=2)
+trainer.fit(model, data_module)
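Patch 02 folds the script's loss construction into overridable defaults on the task: a VGG16 multi-layer encoder, a relu2_2 feature reconstruction loss for content, and a channel-normalized Gram loss over four relu layers for style. Note that it assigns self.perceptual_loss before super().__init__(), which nn.Module's attribute handling rejects; patch 17 below reorders this, so instantiating the task is not yet expected to work at this point. The defaults themselves can be exercised directly with pystiche (only calls already used in the patches above; the plain GramOperator omits the channel normalization, and the random tensor stands in for a real style image):

    import torch
    from pystiche import enc, loss, ops

    # Downloads the VGG16 weights on first use.
    multi_layer_encoder = enc.vgg16_multi_layer_encoder()
    content_loss = ops.FeatureReconstructionOperator(
        multi_layer_encoder.extract_encoder("relu2_2"), score_weight=1e5
    )
    style_loss = ops.MultiLayerEncodingOperator(
        multi_layer_encoder,
        ("relu1_2", "relu2_2", "relu3_3", "relu4_3"),
        lambda encoder, layer_weight: ops.GramOperator(encoder, score_weight=layer_weight),
        layer_weights="sum",
        score_weight=1e10,
    )
    perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
    perceptual_loss.set_style_image(torch.rand(1, 3, 256, 256))  # stand-in style image
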
From 398bbf355c25e01dbe35afa8e0e9fe1791c40d81 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 10 May 2021 07:12:44 +0000
Subject: [PATCH 03/60] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 flash/vision/style_transfer/model.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index 4e918f556a..8a64324d65 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -14,6 +14,7 @@
 class Interpolate(nn.Module):
+
     def __init__(self, scale_factor: float = 1.0, mode: str = "nearest") -> None:
         super().__init__()
         self.scale_factor = scale_factor
@@ -30,6 +31,7 @@ def extra_repr(self) -> str:
 class Conv(nn.Module):
+
     def __init__(
         self,
         in_channels: int,
@@ -63,6 +65,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 class Residual(nn.Module):
+
     def __init__(self, channels: int) -> None:
         super().__init__()
         self.conv1 = Conv(channels, channels, kernel_size=3)
@@ -74,6 +77,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 class Transformer(nn.Module):
+
     def __init__(self) -> None:
         super().__init__()
         self.encoder = nn.Sequential(
@@ -97,6 +101,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 class StyleTransfer(Task):
+
     def __init__(
         self,
@@ -154,7 +159,9 @@ def default_content_loss(
     def default_style_loss(
         self, multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None
     ) -> ops.MultiLayerEncodingOperator:
+
         class GramOperator(ops.GramOperator):
+
             def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
                 repr = super().enc_to_repr(enc)
                 num_channels = repr.size()[1]

From 61c948009793c5de3691367a2167ebd2b124ea05 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 10 May 2021 09:32:15 +0200
Subject: [PATCH 04/60] fix type hint

---
 flash/vision/style_transfer/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index 8a64324d65..c051da1371 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -106,7 +106,7 @@ def __init__(
         self,
         style_image: torch.Tensor,
         model: Optional[nn.Module] = None,
-        multi_layer_encoder: Optional = None,
+        multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None,
         content_loss: Optional[Union[ops.ComparisonOperator, ops.OperatorContainer]] = None,
         style_loss: Optional[Union[ops.ComparisonOperator, ops.OperatorContainer]] = None,
         optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,

From 0a70150b1e844c01b8bde83f8d76ebb568f2ed82 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 10 May 2021 12:18:00 +0200
Subject: [PATCH 05/60] allow passing style_image by path

---
 flash/vision/style_transfer/model.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index c051da1371..9cfa177865 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -3,6 +3,7 @@
 import torch
 import torchmetrics
 from pystiche import enc, loss, ops
+from pystiche.image import read_image
 from torch import nn
 from torch.nn.functional import interpolate
 from torch.optim.lr_scheduler import _LRScheduler
@@ -104,7 +105,7 @@ class StyleTransfer(Task):
 
     def __init__(
         self,
-        style_image: torch.Tensor,
+        style_image: Union[str, torch.Tensor],
         model: Optional[nn.Module] = None,
         multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None,
         content_loss: Optional[Union[ops.ComparisonOperator, ops.OperatorContainer]] = None,
@@ -117,6 +118,9 @@ def __init__(
         learning_rate: float = 1e-3,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
     ):
+        if isinstance(style_image, str):
+            style_image = read_image(style_image)
+
         if multi_layer_encoder is None:
             multi_layer_encoder = self.default_multi_layer_encoder()

From f6b2fcc1da422503460954196c3db37d00c0cc0a Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 10 May 2021 12:20:29 +0200
Subject: [PATCH 06/60] add batch_size

---
 flash_examples/predict/style_transfer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flash_examples/predict/style_transfer.py b/flash_examples/predict/style_transfer.py
index a83f7d6749..d57ffb7438 100644
--- a/flash_examples/predict/style_transfer.py
+++ b/flash_examples/predict/style_transfer.py
@@ -13,7 +13,7 @@
 
 download_data("http://images.cocodataset.org/zips/train2014.zip", "data")
 
-data_module = ImageUnsupervisedData.from_folder("data")
+data_module = ImageUnsupervisedData.from_folder("data", batch_size=4)
 
 style_image = pystiche.demo.images()["paint"].read(size=256, edge="long")
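With patch 05, the style image can be referenced by file path, and pystiche's read_image does the loading internally. A sketch of what now happens for string inputs, assuming that patch is applied; the file name is a hypothetical stand-in for any image on disk:

    from pystiche.image import read_image

    # What StyleTransfer.__init__ now does for a string style_image.
    style_image = read_image("path/to/style_image.jpg")  # hypothetical path
    # read_image yields a batched (1, C, H, W) float tensor, the format
    # PerceptualLoss.set_style_image expects.
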
From edf0dfff5030fd7757a42e18bfbc7bff813ba213 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 11 May 2021 11:18:03 +0200
Subject: [PATCH 07/60] add data_module based on image classification

---
 flash/vision/__init__.py                |  2 +-
 flash/vision/classification/data.py     | 10 ++-
 flash/vision/style_transfer/__init__.py |  1 +
 flash/vision/style_transfer/_utils.py   |  7 ++
 flash/vision/style_transfer/data.py     | 95 +++++++++++++++++++++++
 5 files changed, 112 insertions(+), 3 deletions(-)
 create mode 100644 flash/vision/style_transfer/_utils.py
 create mode 100644 flash/vision/style_transfer/data.py

diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py
index 81c24bf13e..ddba3efebc 100644
--- a/flash/vision/__init__.py
+++ b/flash/vision/__init__.py
@@ -2,4 +2,4 @@
 from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier
 from flash.vision.detection import ObjectDetectionData, ObjectDetector
 from flash.vision.embedding import ImageEmbedder
-from flash.vision.style_transfer import StyleTransfer
+from flash.vision.style_transfer import StyleTransferPreprocess, StyleTransferData, StyleTransfer

diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py
index ac492461de..ce61a49ec1 100644
--- a/flash/vision/classification/data.py
+++ b/flash/vision/classification/data.py
@@ -160,6 +160,12 @@ def default_val_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]:
             "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         }
 
+    def default_test_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]:
+        return self.default_val_transforms(image_size)
+
+    def default_predict_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]:
+        return self.default_val_transforms(image_size)
+
     def _resolve_transforms(
         self,
         train_transform: Optional[Union[str, Dict]] = 'default',
@@ -176,10 +182,10 @@ def _resolve_transforms(
             val_transform = self.default_val_transforms(image_size)
 
         if not test_transform or test_transform == 'default':
-            test_transform = self.default_val_transforms(image_size)
+            test_transform = self.default_test_transforms(image_size)
 
         if not predict_transform or predict_transform == 'default':
-            predict_transform = self.default_val_transforms(image_size)
+            predict_transform = self.default_predict_transforms(image_size)
 
         return (
             train_transform,

diff --git a/flash/vision/style_transfer/__init__.py b/flash/vision/style_transfer/__init__.py
index aa70fa9459..7a7663587b 100644
--- a/flash/vision/style_transfer/__init__.py
+++ b/flash/vision/style_transfer/__init__.py
@@ -1 +1,2 @@
+from flash.vision.style_transfer.data import StyleTransferPreprocess, StyleTransferData
 from flash.vision.style_transfer.model import StyleTransfer

diff --git a/flash/vision/style_transfer/_utils.py b/flash/vision/style_transfer/_utils.py
new file mode 100644
index 0000000000..65bf503adb
--- /dev/null
+++ b/flash/vision/style_transfer/_utils.py
@@ -0,0 +1,7 @@
+from typing import NoReturn
+
+__all__ = ["raise_not_supported"]
+
+
+def raise_not_supported(phase: str) -> NoReturn:
+    raise RuntimeError(f"Style transfer does not support a {phase} phase.")

diff --git a/flash/vision/style_transfer/data.py b/flash/vision/style_transfer/data.py
new file mode 100644
index 0000000000..a507c27b90
--- /dev/null
+++ b/flash/vision/style_transfer/data.py
@@ -0,0 +1,95 @@
+import pathlib
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torchvision
+from torch import nn
+from torchvision import transforms
+
+from flash.data.process import Preprocess
+from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess
+
+from ._utils import raise_not_supported
+
+__all__ = ["StyleTransferPreprocess", "StyleTransferData"]
+
+
+class OptionalGrayscaleToFakeGrayscale(nn.Module):
+    def forward(self, input):
+        num_channels = input.size()[0]
+        if num_channels > 1:
+            return input
+
+        return input.repeat(3, 1, 1)
+
+
+class StyleTransferPreprocess(ImageClassificationPreprocess):
+    def __init__(
+        self,
+        train_transform: Optional[Union[Dict[str, Callable]]] = None,
+        predict_transform: Optional[Union[Dict[str, Callable]]] = None,
+        image_size: Union[int, Tuple[int, int]] = 256,
+    ):
+        if isinstance(image_size, int):
+            image_size = (image_size, image_size)
+        super().__init__(
+            train_transform=train_transform,
+            predict_transform=predict_transform,
+            image_size=image_size,
+        )
+
+    def default_train_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]:
+        return dict(
+            to_tensor_transform=torchvision.transforms.ToTensor(),
+            # Some datasets, such as the one used in flash_examples/finetuning/style_transfer.py contain some rogue
+            # grayscale images. To not interrupt the training flow, we simply convert them to fake grayscale, by
+            # repeating the values for three channels, mimicking an RGB image.
+            post_tensor_transform=OptionalGrayscaleToFakeGrayscale(),
+            per_batch_transform_on_device=nn.Sequential(
+                transforms.Resize(min(image_size)),
+                transforms.CenterCrop(image_size),
+            ),
+        )
+
+    def default_val_transforms(self, image_size: Any) -> Dict[str, Callable]:
+        # Style transfer doesn't support a validation phase, so we return nothing here
+        return {}
+
+    def default_test_transforms(self, image_size: Any) -> Dict[str, Callable]:
+        # Style transfer doesn't support a test phase, so we return nothing here
+        return {}
+
+    def default_predict_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]:
+        return dict(
+            to_tensor_transform=torchvision.transforms.ToTensor(),
+            per_batch_transform_on_device=nn.Sequential(
+                transforms.Resize(min(image_size)),
+            ),
+        )
+
+
+class StyleTransferData(ImageClassificationData):
+    preprocess_cls = StyleTransferPreprocess
+
+    @classmethod
+    def from_folders(
+        cls,
+        train_folder: Optional[Union[str, pathlib.Path]] = None,
+        predict_folder: Optional[Union[str, pathlib.Path]] = None,
+        train_transform: Optional[Union[str, Dict]] = "default",
+        predict_transform: Optional[Union[str, Dict]] = "default",
+        preprocess: Optional[Preprocess] = None,
+        **kwargs: Any,
+    ) -> "StyleTransferData":
+        if any(param in kwargs for param in ("val_folder", "val_transform")):
+            raise_not_supported("validation")
+        if any(param in kwargs for param in ("test_folder", "test_transform")):
+            raise_not_supported("test")
+
+        preprocess = preprocess or cls.preprocess_cls(train_transform, predict_transform)
+
+        return cls.from_load_data_inputs(
+            train_load_data_input=train_folder,
+            predict_load_data_input=predict_folder,
+            preprocess=preprocess,
+            **kwargs,
+        )

From 4ffc734c3c39c76cafb6f4285790858d443cfcd3 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 11 May 2021 11:19:21 +0200
Subject: [PATCH 08/60] add internal pre / post-processing

---
 flash/vision/style_transfer/model.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index 9cfa177865..a3b484137a 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -77,6 +77,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return output + input
 
 
+class FloatToUint8Range(nn.Module):
+    def forward(self, input):
+        return input * 255.0
+
+
+class Uint8ToFloatRange(nn.Module):
+    def forward(self, input):
+        return input / 255.0
+
+
 class Transformer(nn.Module):
 
     def __init__(self) -> None:
@@ -97,8 +107,13 @@ def __init__(self) -> None:
             Conv(32, 3, kernel_size=9, norm=False, activation=False),
         )
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return self.decoder(self.encoder(input))
+        self.preprocessor = FloatToUint8Range()
+        self.postprocessor = Uint8ToFloatRange()
+
+    def forward(self, input):
+        input = self.preprocessor(input)
+        output = self.decoder(self.encoder(input))
+        return self.postprocessor(output)

From 9bb45bcec4a5203627c48d44e99ac9bf556a3e6a Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 11 May 2021 11:19:58 +0200
Subject: [PATCH 09/60] bail out if val / test step is performed

---
 flash/vision/style_transfer/model.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index a3b484137a..1c6583867c 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -1,7 +1,8 @@
-from typing import Any, Dict, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Dict, Mapping, NoReturn, Optional, Sequence, Type, Union
 
 import torch
 import torchmetrics
+from _utils import raise_not_supported
 from pystiche import enc, loss, ops
 from pystiche.image import read_image
 from torch import nn
@@ -202,3 +203,9 @@ def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
     def forward(self, content_image: torch.Tensor) -> torch.Tensor:
         self.perceptual_loss.set_content_image(content_image)
         return self.model(content_image)
+
+    def validation_step(self, batch: Any, batch_idx: int) -> NoReturn:
+        raise_not_supported("validation")
+
+    def test_step(self, batch: Any, batch_idx: int) -> NoReturn:
+        raise_not_supported("test")
\ No newline at end of file

From e82a94ca7339c2504f29d8c7abf30ae441a30b36 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 11 May 2021 11:21:36 +0200
Subject: [PATCH 10/60] update example

---
 flash_examples/predict/style_transfer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flash_examples/predict/style_transfer.py b/flash_examples/predict/style_transfer.py
index d57ffb7438..b2ee398834 100644
--- a/flash_examples/predict/style_transfer.py
+++ b/flash_examples/predict/style_transfer.py
@@ -3,7 +3,7 @@
 import flash
 from flash.data.utils import download_data
 from flash.utils.imports import _PYSTICHE_AVAILABLE
-from flash.vision.style_transfer import StyleTransfer
+from flash.vision.style_transfer import StyleTransfer, StyleTransferData
 
 if _PYSTICHE_AVAILABLE:
     import pystiche.demo
@@ -13,9 +13,9 @@
 
 download_data("http://images.cocodataset.org/zips/train2014.zip", "data")
 
-data_module = ImageUnsupervisedData.from_folder("data", batch_size=4)
+data_module = StyleTransferData.from_folders(train_folder="data", batch_size=4)
 
-style_image = pystiche.demo.images()["paint"].read(size=256, edge="long")
+style_image = pystiche.demo.images()["paint"].read(size=256)
 
 model = StyleTransfer(style_image)
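Patches 07 and 10 give style transfer its own data module and route the example through it. A usage sketch mirroring the example's call, assuming the COCO archive has been downloaded into data/ as in the script; per the from_folders guard above, passing val_folder or test_folder raises via raise_not_supported:

    from flash.vision.style_transfer import StyleTransferData

    # Only train and predict phases exist for style transfer.
    data_module = StyleTransferData.from_folders(train_folder="data", batch_size=4)
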
From d939fad7e49b1402c821b7a13ddbccc11460a7c0 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 11 May 2021 11:22:25 +0200
Subject: [PATCH 11/60] move example from predict to finetuning

---
 flash_examples/{predict => finetuning}/style_transfer.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename flash_examples/{predict => finetuning}/style_transfer.py (100%)

diff --git a/flash_examples/predict/style_transfer.py b/flash_examples/finetuning/style_transfer.py
similarity index 100%
rename from flash_examples/predict/style_transfer.py
rename to flash_examples/finetuning/style_transfer.py

From 5b10dbbbfaaa3c7dbca9b25a6ff1e7970e21104b Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 11 May 2021 11:23:14 +0200
Subject: [PATCH 12/60] remove metrics from task

---
 flash/vision/style_transfer/model.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index 1c6583867c..998e6c7ec6 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -130,7 +130,6 @@ def __init__(
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
         scheduler_kwargs: Optional[Dict[str, Any]] = None,
-        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
         learning_rate: float = 1e-3,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
     ):
@@ -158,7 +157,6 @@ def __init__(
             optimizer_kwargs=optimizer_kwargs,
             scheduler=scheduler,
             scheduler_kwargs=scheduler_kwargs,
-            metrics=metrics,
             learning_rate=learning_rate,
             serializer=serializer,
         )

From 2e6901aa9fe066c9dfa06b4c6b019e28104269f5 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 11 May 2021 11:26:02 +0200
Subject: [PATCH 13/60] flake8

---
 flash/vision/style_transfer/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index 998e6c7ec6..cf52bc7585 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -206,4 +206,4 @@ def validation_step(self, batch: Any, batch_idx: int) -> NoReturn:
         raise_not_supported("validation")
 
     def test_step(self, batch: Any, batch_idx: int) -> NoReturn:
-        raise_not_supported("test")
\ No newline at end of file
+        raise_not_supported("test")

From eeed004513f0b0e76a88f781c6a5235454b62fb3 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 11 May 2021 11:35:08 +0200
Subject: [PATCH 14/60] remove unused imports

---
 flash/vision/style_transfer/model.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index 157fbb9fd4..f2180a88e0 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -1,7 +1,6 @@
-from typing import Any, Dict, Mapping, NoReturn, Optional, Sequence, Type, Union
+from typing import Any, Dict, Mapping, NoReturn, Optional, Type, Union
 
 import torch
-import torchmetrics
 from _utils import raise_not_supported
 from pystiche import enc, loss, ops
 from pystiche.image import read_image

From 7d38a5b657ef10f985c91cc4afc348bf47da6e07 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 11 May 2021 09:36:03 +0000
Subject: [PATCH 15/60] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 flash/vision/__init__.py                | 2 +-
 flash/vision/style_transfer/__init__.py | 2 +-
 flash/vision/style_transfer/data.py     | 6 +++---
 flash/vision/style_transfer/model.py    | 9 +++++++++
 4 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py
index ba4f9b8f10..937c6888a7 100644
--- a/flash/vision/__init__.py
+++ b/flash/vision/__init__.py
@@ -3,4 +3,4 @@
 from flash.vision.detection import ObjectDetectionData, ObjectDetector
 from flash.vision.embedding import ImageEmbedder
 from flash.vision.segmentation import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
-from flash.vision.style_transfer import StyleTransferPreprocess, StyleTransferData, StyleTransfer
+from flash.vision.style_transfer import StyleTransfer, StyleTransferData, StyleTransferPreprocess

diff --git a/flash/vision/style_transfer/__init__.py b/flash/vision/style_transfer/__init__.py
index 7a7663587b..3cac970788 100644
--- a/flash/vision/style_transfer/__init__.py
+++ b/flash/vision/style_transfer/__init__.py
@@ -1,2 +1,2 @@
-from flash.vision.style_transfer.data import StyleTransferPreprocess, StyleTransferData
+from flash.vision.style_transfer.data import StyleTransferData, StyleTransferPreprocess
 from flash.vision.style_transfer.model import StyleTransfer

diff --git a/flash/vision/style_transfer/data.py b/flash/vision/style_transfer/data.py
index 82be7a48d8..e717c01663 100644
--- a/flash/vision/style_transfer/data.py
+++ b/flash/vision/style_transfer/data.py
@@ -14,6 +14,7 @@
 class OptionalGrayscaleToFakeGrayscale(nn.Module):
+
     def forward(self, input):
         num_channels = input.size()[0]
         if num_channels > 1:
@@ -23,6 +24,7 @@ def forward(self, input):
 class StyleTransferPreprocess(ImageClassificationPreprocess):
+
     def __init__(
         self,
         train_transform: Optional[Union[Dict[str, Callable]]] = None,
@@ -65,9 +67,7 @@ def default_test_transforms(self) -> None:
     def default_predict_transforms(self) -> Dict[str, Callable]:
         return dict(
             to_tensor_transform=torchvision.transforms.ToTensor(),
-            per_batch_transform_on_device=nn.Sequential(
-                transforms.Resize(min(self.image_size)),
-            ),
+            per_batch_transform_on_device=nn.Sequential(transforms.Resize(min(self.image_size)), ),
         )

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index f2180a88e0..dd6e07c101 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -15,6 +15,7 @@
 class Interpolate(nn.Module):
+
     def __init__(self, scale_factor: float = 1.0, mode: str = "nearest") -> None:
         super().__init__()
         self.scale_factor = scale_factor
@@ -31,6 +32,7 @@ def extra_repr(self) -> str:
 class Conv(nn.Module):
+
     def __init__(
         self,
         in_channels: int,
@@ -64,6 +66,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 class Residual(nn.Module):
+
     def __init__(self, channels: int) -> None:
         super().__init__()
         self.conv1 = Conv(channels, channels, kernel_size=3)
@@ -75,16 +78,19 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 class FloatToUint8Range(nn.Module):
+
     def forward(self, input):
         return input * 255.0
 
 
 class Uint8ToFloatRange(nn.Module):
+
     def forward(self, input):
         return input / 255.0
 
 
 class Transformer(nn.Module):
+
     def __init__(self) -> None:
         super().__init__()
         self.encoder = nn.Sequential(
@@ -113,6 +119,7 @@ def forward(self, input):
 class StyleTransfer(Task):
+
     def __init__(
         self,
         style_image: Union[str, torch.Tensor],
@@ -171,7 +178,9 @@ def default_content_loss(
     def default_style_loss(
         self, multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None
    ) -> ops.MultiLayerEncodingOperator:
+
         class GramOperator(ops.GramOperator):
+
             def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
                 repr = super().enc_to_repr(enc)
                 num_channels = repr.size()[1]
From e844db4ebba1bbdaaa210658a489521c81284b74 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 11 May 2021 11:40:15 +0200
Subject: [PATCH 16/60] remove grayscale handling

---
 flash/vision/style_transfer/data.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/flash/vision/style_transfer/data.py b/flash/vision/style_transfer/data.py
index e717c01663..7fa342e0c4 100644
--- a/flash/vision/style_transfer/data.py
+++ b/flash/vision/style_transfer/data.py
@@ -13,16 +13,6 @@
 __all__ = ["StyleTransferPreprocess", "StyleTransferData"]
 
 
-class OptionalGrayscaleToFakeGrayscale(nn.Module):
-
-    def forward(self, input):
-        num_channels = input.size()[0]
-        if num_channels > 1:
-            return input
-
-        return input.repeat(3, 1, 1)
-
-
 class StyleTransferPreprocess(ImageClassificationPreprocess):
 
     def __init__(
@@ -43,10 +33,6 @@ def __init__(
     def default_train_transforms(self) -> Dict[str, Callable]:
         return dict(
             to_tensor_transform=torchvision.transforms.ToTensor(),
-            # Some datasets, such as the one used in flash_examples/finetuning/style_transfer.py contain some rogue
-            # grayscale images. To not interrupt the training flow, we simply convert them to fake grayscale, by
-            # repeating the values for three channels, mimicking an RGB image.
-            post_tensor_transform=OptionalGrayscaleToFakeGrayscale(),
             per_batch_transform_on_device=nn.Sequential(
                 transforms.Resize(min(self.image_size)),
                 transforms.CenterCrop(self.image_size),

From a932f86e44cbfc140b995f8cdb1494f8e957be4e Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 11 May 2021 21:04:20 +0200
Subject: [PATCH 17/60] address review comments and small fixes

---
 flash/vision/style_transfer/data.py         | 32 +++++++++++++--------
 flash/vision/style_transfer/model.py        | 29 ++++++++++---------
 flash_examples/finetuning/style_transfer.py |  2 ++
 3 files changed, 38 insertions(+), 25 deletions(-)

diff --git a/flash/vision/style_transfer/data.py b/flash/vision/style_transfer/data.py
index 7fa342e0c4..e9f66e36b4 100644
--- a/flash/vision/style_transfer/data.py
+++ b/flash/vision/style_transfer/data.py
@@ -5,7 +5,9 @@
 from torch import nn
 from torchvision import transforms
 
+from flash.data.data_source import DefaultDataKeys, DefaultDataSources
 from flash.data.process import Preprocess
+from flash.data.transforms import ApplyToKeys
 from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess
 
 from ._utils import raise_not_supported
@@ -14,7 +16,6 @@
 
 class StyleTransferPreprocess(ImageClassificationPreprocess):
-
     def __init__(
         self,
         train_transform: Optional[Union[Dict[str, Callable]]] = None,
@@ -29,13 +30,18 @@ def __init__(
             image_size=image_size,
         )
 
+    def _apply_to_input(self, transform: Callable) -> ApplyToKeys:
+        return ApplyToKeys(DefaultDataKeys.INPUT, transform)
+
     @property
     def default_train_transforms(self) -> Dict[str, Callable]:
         return dict(
-            to_tensor_transform=torchvision.transforms.ToTensor(),
-            per_batch_transform_on_device=nn.Sequential(
-                transforms.Resize(min(self.image_size)),
-                transforms.CenterCrop(self.image_size),
+            to_tensor_transform=self._apply_to_input(torchvision.transforms.ToTensor()),
+            per_batch_transform_on_device=self._apply_to_input(
+                nn.Sequential(
+                    transforms.Resize(min(self.image_size)),
+                    transforms.CenterCrop(self.image_size),
+                )
             ),
         )
@@ -51,9 +57,10 @@ def default_test_transforms(self) -> None:
 
     @property
     def default_predict_transforms(self) -> Dict[str, Callable]:
+
         return dict(
-            to_tensor_transform=torchvision.transforms.ToTensor(),
-            per_batch_transform_on_device=nn.Sequential(transforms.Resize(min(self.image_size)), ),
+            to_tensor_transform=self._apply_to_input(torchvision.transforms.ToTensor()),
+            per_batch_transform_on_device=self._apply_to_input(nn.Sequential(transforms.Resize(min(self.image_size)))),
         )
@@ -65,8 +72,8 @@ from_folders(
         cls,
         train_folder: Optional[Union[str, pathlib.Path]] = None,
         predict_folder: Optional[Union[str, pathlib.Path]] = None,
-        train_transform: Optional[Union[str, Dict]] = "default",
-        predict_transform: Optional[Union[str, Dict]] = "default",
+        train_transform: Optional[Union[str, Dict]] = None,
+        predict_transform: Optional[Union[str, Dict]] = None,
         preprocess: Optional[Preprocess] = None,
         **kwargs: Any,
     ) -> "StyleTransferData":
@@ -77,9 +84,10 @@ def from_folders(
 
         preprocess = preprocess or cls.preprocess_cls(train_transform, predict_transform)
 
-        return cls.from_load_data_inputs(
-            train_load_data_input=train_folder,
-            predict_load_data_input=predict_folder,
+        return cls.from_data_source(
+            DefaultDataSources.PATHS,
+            train_data=train_folder,
+            predict_data=predict_folder,
             preprocess=preprocess,
             **kwargs,
         )

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index dd6e07c101..c1d7f1a252 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -1,7 +1,6 @@
 from typing import Any, Dict, Mapping, NoReturn, Optional, Type, Union
 
 import torch
-from _utils import raise_not_supported
 from pystiche import enc, loss, ops
 from pystiche.image import read_image
 from torch import nn
@@ -9,13 +8,15 @@
 from torch.optim.lr_scheduler import _LRScheduler
 
 from flash.core import Task
+from flash.data.data_source import DefaultDataKeys
 from flash.data.process import Serializer
 
+from ._utils import raise_not_supported
+
 __all__ = ["StyleTransfer"]
 
 
 class Interpolate(nn.Module):
-
     def __init__(self, scale_factor: float = 1.0, mode: str = "nearest") -> None:
         super().__init__()
         self.scale_factor = scale_factor
@@ -32,7 +33,6 @@ def extra_repr(self) -> str:
 
 class Conv(nn.Module):
-
     def __init__(
         self,
         in_channels: int,
@@ -66,7 +66,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 class Residual(nn.Module):
-
     def __init__(self, channels: int) -> None:
         super().__init__()
         self.conv1 = Conv(channels, channels, kernel_size=3)
@@ -78,19 +77,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 class FloatToUint8Range(nn.Module):
-
     def forward(self, input):
         return input * 255.0
 
 
 class Uint8ToFloatRange(nn.Module):
-
     def forward(self, input):
         return input / 255.0
 
 
 class Transformer(nn.Module):
-
     def __init__(self) -> None:
         super().__init__()
         self.encoder = nn.Sequential(
@@ -119,7 +115,6 @@ def forward(self, input):
 
 class StyleTransfer(Task):
-
     def __init__(
         self,
         style_image: Union[str, torch.Tensor],
@@ -137,6 +132,10 @@ def __init__(
         if isinstance(style_image, str):
             style_image = read_image(style_image)
 
+        if model is None:
+            # TODO: import this from pystiche
+            model = Transformer()
+
         if multi_layer_encoder is None:
             multi_layer_encoder = self.default_multi_layer_encoder()
 
@@ -146,14 +145,13 @@ def __init__(
         if style_loss is None:
             style_loss = self.default_style_loss(multi_layer_encoder)
 
-        self.perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
-        self.perceptual_loss.set_style_image(style_image)
+        perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
 
         self.save_hyperparameters()
 
         super().__init__(
             model=model,
-            loss_fn=self.perceptual_loss,
+            loss_fn=perceptual_loss,
             optimizer=optimizer,
             optimizer_kwargs=optimizer_kwargs,
             scheduler=scheduler,
@@ -162,6 +160,10 @@ def __init__(
             serializer=serializer,
         )
 
+        # can't assign modules before super init call
+        self.perceptual_loss = perceptual_loss
+        self.perceptual_loss.set_style_image(style_image)
+
     def default_multi_layer_encoder(self) -> enc.MultiLayerEncoder:
         return enc.vgg16_multi_layer_encoder()
 
@@ -178,9 +180,7 @@ def default_content_loss(
     def default_style_loss(
         self, multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None
     ) -> ops.MultiLayerEncodingOperator:
-
         class GramOperator(ops.GramOperator):
-
             def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
                 repr = super().enc_to_repr(enc)
                 num_channels = repr.size()[1]
@@ -203,6 +203,9 @@ def forward(self, content_image: torch.Tensor) -> torch.Tensor:
         self.perceptual_loss.set_content_image(content_image)
         return self.model(content_image)
 
+    def training_step(self, batch: Any, batch_idx: int) -> Any:
+        return super().training_step(batch[DefaultDataKeys.INPUT], batch_idx)
+
     def validation_step(self, batch: Any, batch_idx: int) -> NoReturn:
         raise_not_supported("validation")

diff --git a/flash_examples/finetuning/style_transfer.py b/flash_examples/finetuning/style_transfer.py
index b2ee398834..56ff5237f8 100644
--- a/flash_examples/finetuning/style_transfer.py
+++ b/flash_examples/finetuning/style_transfer.py
@@ -21,3 +21,5 @@
 
 trainer = flash.Trainer(max_epochs=2)
 trainer.fit(model, data_module)
+
+trainer.save_checkpoint("style_transfer_model.pt")

From ba091ccadee8b1a28f743babf1cd8354e8778604 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 11 May 2021 19:04:59 +0000
Subject: [PATCH 18/60] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 flash/vision/style_transfer/data.py  | 1 +
 flash/vision/style_transfer/model.py | 9 +++++++++
 2 files changed, 10 insertions(+)

diff --git a/flash/vision/style_transfer/data.py b/flash/vision/style_transfer/data.py
index e9f66e36b4..bc672e6a17 100644
--- a/flash/vision/style_transfer/data.py
+++ b/flash/vision/style_transfer/data.py
@@ -16,6 +16,7 @@
 class StyleTransferPreprocess(ImageClassificationPreprocess):
+
     def __init__(
         self,
         train_transform: Optional[Union[Dict[str, Callable]]] = None,

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index c1d7f1a252..fc292f93fe 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -17,6 +17,7 @@
 class Interpolate(nn.Module):
+
     def __init__(self, scale_factor: float = 1.0, mode: str = "nearest") -> None:
         super().__init__()
         self.scale_factor = scale_factor
@@ -33,6 +34,7 @@ def extra_repr(self) -> str:
 class Conv(nn.Module):
+
     def __init__(
         self,
         in_channels: int,
@@ -66,6 +68,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 class Residual(nn.Module):
+
     def __init__(self, channels: int) -> None:
         super().__init__()
         self.conv1 = Conv(channels, channels, kernel_size=3)
@@ -77,16 +80,19 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 class FloatToUint8Range(nn.Module):
+
     def forward(self, input):
         return input * 255.0
 
 
 class Uint8ToFloatRange(nn.Module):
+
     def forward(self, input):
         return input / 255.0
 
 
 class Transformer(nn.Module):
+
     def __init__(self) -> None:
         super().__init__()
         self.encoder = nn.Sequential(
@@ -115,6 +121,7 @@ def forward(self, input):
 class StyleTransfer(Task):
+
     def __init__(
         self,
         style_image: Union[str, torch.Tensor],
@@ -180,7 +187,9 @@ def default_content_loss(
     def default_style_loss(
         self, multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None
     ) -> ops.MultiLayerEncodingOperator:
+
         class GramOperator(ops.GramOperator):
+
             def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
                 repr = super().enc_to_repr(enc)
                 num_channels = repr.size()[1]

From
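Patch 17 wraps every default transform in ApplyToKeys(DefaultDataKeys.INPUT, ...) because samples now travel through the pipeline as dictionaries rather than bare tensors, and pulls the input out of the batch dict in training_step. A sketch of the intended behavior of that wrapper, assuming flash 0.3-era APIs; the random tensor is a stand-in for a decoded image:

    import torch
    from torchvision import transforms

    from flash.data.data_source import DefaultDataKeys
    from flash.data.transforms import ApplyToKeys

    # The wrapped transform runs only on the "input" entry of the sample
    # dict; any other keys (targets, metadata) pass through untouched.
    transform = ApplyToKeys(DefaultDataKeys.INPUT, transforms.Resize(256))
    sample = {DefaultDataKeys.INPUT: torch.rand(3, 512, 512)}
    resized = transform(sample)[DefaultDataKeys.INPUT]  # now 3 x 256 x 256
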
e399d5095a323d2a419f23051cb305394c4ef499 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 12 May 2021 10:24:10 +0200
Subject: [PATCH 19/60] streamline apply_to_input

---
 flash/vision/style_transfer/data.py | 30 +++++++++++++++++-------------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/flash/vision/style_transfer/data.py b/flash/vision/style_transfer/data.py
index bc672e6a17..1326b03aeb 100644
--- a/flash/vision/style_transfer/data.py
+++ b/flash/vision/style_transfer/data.py
@@ -1,3 +1,4 @@
+import functools
 import pathlib
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
@@ -15,8 +16,16 @@
 __all__ = ["StyleTransferPreprocess", "StyleTransferData"]
 
 
-class StyleTransferPreprocess(ImageClassificationPreprocess):
+def _apply_to_input(default_transforms_fn) -> Callable[..., Dict[str, ApplyToKeys]]:
+    @functools.wraps(default_transforms_fn)
+    def wrapper(*args: Any, **kwargs: Any) -> Dict[str, ApplyToKeys]:
+        default_transforms = default_transforms_fn(*args, **kwargs)
+        return {hook: ApplyToKeys(DefaultDataKeys.INPUT, transform) for hook, transform in default_transforms.items()}
+
+    return wrapper
 
+
+class StyleTransferPreprocess(ImageClassificationPreprocess):
     def __init__(
         self,
         train_transform: Optional[Union[Dict[str, Callable]]] = None,
@@ -31,18 +40,14 @@ def __init__(
             image_size=image_size,
         )
 
-    def _apply_to_input(self, transform: Callable) -> ApplyToKeys:
-        return ApplyToKeys(DefaultDataKeys.INPUT, transform)
-
     @property
+    @_apply_to_input
     def default_train_transforms(self) -> Dict[str, Callable]:
         return dict(
-            to_tensor_transform=self._apply_to_input(torchvision.transforms.ToTensor()),
-            per_batch_transform_on_device=self._apply_to_input(
-                nn.Sequential(
-                    transforms.Resize(min(self.image_size)),
-                    transforms.CenterCrop(self.image_size),
-                )
+            to_tensor_transform=torchvision.transforms.ToTensor(),
+            per_sample_transform_on_device=nn.Sequential(
+                transforms.Resize(min(self.image_size)),
+                transforms.CenterCrop(self.image_size),
             ),
         )
 
@@ -57,11 +62,12 @@ def default_test_transforms(self) -> None:
         return None
 
     @property
+    @_apply_to_input
     def default_predict_transforms(self) -> Dict[str, Callable]:
         return dict(
-            to_tensor_transform=self._apply_to_input(torchvision.transforms.ToTensor()),
-            per_batch_transform_on_device=self._apply_to_input(nn.Sequential(transforms.Resize(min(self.image_size)))),
+            to_tensor_transform=torchvision.transforms.ToTensor(),
+            per_sample_transform_on_device=transforms.Resize(min(self.image_size)),
         )

From dc80a2178b39d55fd7e785fb25fbe208cb49e63c Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 12 May 2021 10:25:07 +0200
Subject: [PATCH 20/60] fix hyper parameters saving

---
 flash/vision/style_transfer/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index fc292f93fe..33aefb43ad 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -136,6 +136,8 @@ def __init__(
         learning_rate: float = 1e-3,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
     ):
+        self.save_hyperparameters(ignore="style_image")
+
         if isinstance(style_image, str):
             style_image = read_image(style_image)
 
@@ -154,8 +156,6 @@ def __init__(
 
         perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
 
-        self.save_hyperparameters()
-
         super().__init__(
             model=model,

From 9f7fd4136a18893835a378e7da46ea2f02c7216a Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 12 May 2021 10:25:33 +0200
Subject: [PATCH 21/60] implement custom step

---
 flash/vision/style_transfer/model.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index 33aefb43ad..5b531136bb 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -208,9 +208,15 @@ def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
             score_weight=style_weight,
         )
 
-    def forward(self, content_image: torch.Tensor) -> torch.Tensor:
-        self.perceptual_loss.set_content_image(content_image)
-        return self.model(content_image)
+    def step(self, batch: Any, batch_idx: int) -> Any:
+        input_image = batch
+        self.perceptual_loss.set_content_image(input_image)
+
+        output_image = self(batch)
+        loss = self.perceptual_loss(output_image).total()
+
+        logs = dict(perceptual_loss=loss)
+        return dict(loss=loss, logs=logs)
 
     def training_step(self, batch: Any, batch_idx: int) -> Any:
         return super().training_step(batch[DefaultDataKeys.INPUT], batch_idx)

From eabf49b505ea1c555acb76c8d58764f2a6c6c14b Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 12 May 2021 10:25:52 +0200
Subject: [PATCH 22/60] cleanup

---
 flash/vision/style_transfer/model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py
index 5b531136bb..1246638b23 100644
--- a/flash/vision/style_transfer/model.py
+++ b/flash/vision/style_transfer/model.py
@@ -167,7 +167,6 @@ def __init__(
             serializer=serializer,
         )
 
-        # can't assign modules before super init call
         self.perceptual_loss = perceptual_loss
         self.perceptual_loss.set_style_image(style_image)

From 361074b8f7b2be56cf17263c93124a3880de763e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 12 May 2021 08:26:54 +0000
Subject: [PATCH 23/60] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 flash/vision/style_transfer/data.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/flash/vision/style_transfer/data.py b/flash/vision/style_transfer/data.py
index 1326b03aeb..67fbd7a416 100644
--- a/flash/vision/style_transfer/data.py
+++ b/flash/vision/style_transfer/data.py
@@ -17,6 +17,7 @@
 def _apply_to_input(default_transforms_fn) -> Callable[..., Dict[str, ApplyToKeys]]:
+
     @functools.wraps(default_transforms_fn)
     def wrapper(*args: Any, **kwargs: Any) -> Dict[str, ApplyToKeys]:
         default_transforms = default_transforms_fn(*args, **kwargs)
@@ -26,6 +27,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Dict[str, ApplyToKeys]:
 
 class StyleTransferPreprocess(ImageClassificationPreprocess):
+
     def __init__(
         self,
         train_transform: Optional[Union[Dict[str, Callable]]] = None,
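Patch 21 replaces the forward-with-side-effects by an explicit step: the batch serves both as content target and as transformer input, only the style reference stays fixed, and the scalar loss comes from the pystiche LossDict's total(). The equivalent computation spelled out, assuming the series up to this point is applied; the random tensors are stand-ins for real images, and constructing the task downloads the VGG16 weights on first use:

    import torch

    from flash.vision.style_transfer import StyleTransfer

    task = StyleTransfer(torch.rand(1, 3, 256, 256))  # stand-in style image

    batch = torch.rand(2, 3, 256, 256)  # stand-in content batch
    task.perceptual_loss.set_content_image(batch)
    stylized = task(batch)  # Task.forward routes through self.model
    loss = task.perceptual_loss(stylized).total()
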
From 464a26c79d44a3da1b6aec1e1379671d212068cd Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 14 May 2021 20:44:58 +0200
Subject: [PATCH 24/60] add explanation to not supported phases

---
 flash/vision/style_transfer/_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/flash/vision/style_transfer/_utils.py b/flash/vision/style_transfer/_utils.py
index 65bf503adb..dc34ba6d4f 100644
--- a/flash/vision/style_transfer/_utils.py
+++ b/flash/vision/style_transfer/_utils.py
@@ -4,4 +4,7 @@
 
 
 def raise_not_supported(phase: str) -> NoReturn:
-    raise RuntimeError(f"Style transfer does not support a {phase} phase.")
+    raise RuntimeError(
+        f"Style transfer does not support a {phase} phase, "
+        f"since there is no metric to objectively determine the quality of a stylization."
+    )

From d78989e8e483e2e270953d7f5e84d17590129751 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 14 May 2021 20:46:30 +0200
Subject: [PATCH 25/60] temporarily use unreleased pystiche version

---
 requirements.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index af3bde36c7..2cb711f420 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,4 +19,5 @@
 pycocotools>=2.0.2 ; python_version >= "3.7"
 kornia>=0.5.1
 pytorchvideo
 matplotlib # used by the visualisation callback
-pystiche>=0.7.1
+git+https://github.com/pystiche/pystiche@releases/v0.7.2
+# pystiche>=0.7.2

From 1b2e6e30a2a15bd7637dc182d87d886ec68d80a2 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 14 May 2021 20:46:48 +0200
Subject: [PATCH 26/60] add missing transforms in preprocess

---
 flash/vision/style_transfer/data.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/flash/vision/style_transfer/data.py b/flash/vision/style_transfer/data.py
index 67fbd7a416..7de6c730d1 100644
--- a/flash/vision/style_transfer/data.py
+++ b/flash/vision/style_transfer/data.py
@@ -31,13 +31,22 @@ class StyleTransferPreprocess(ImageClassificationPreprocess):
     def __init__(
         self,
         train_transform: Optional[Union[Dict[str, Callable]]] = None,
+        val_transform: Optional[Union[Dict[str, Callable]]] = None,
+        test_transform: Optional[Union[Dict[str, Callable]]] = None,
         predict_transform: Optional[Union[Dict[str, Callable]]] = None,
         image_size: Union[int, Tuple[int, int]] = 256,
     ):
+        if val_transform:
+            raise_not_supported("validation")
+        if test_transform:
+            raise_not_supported("test")
+
         if isinstance(image_size, int):
             image_size = (image_size, image_size)
         super().__init__(
             train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
             predict_transform=predict_transform,
             image_size=image_size,
         )

From 0feaf7a05cadb1aff5219404a901abfce1c47f95 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 14 May 2021 20:48:51 +0200
Subject: [PATCH 27/60] introduce multi layer encoders as backbones

---
 flash/vision/style_transfer/backbone.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)
 create mode 100644 flash/vision/style_transfer/backbone.py

diff --git a/flash/vision/style_transfer/backbone.py b/flash/vision/style_transfer/backbone.py
new file mode 100644
index 0000000000..1967f14ee0
--- /dev/null
+++ b/flash/vision/style_transfer/backbone.py
@@ -0,0 +1,23 @@
+import re
+
+from pystiche import enc
+
+from flash.core.registry import FlashRegistry
+
+__all__ = ["STYLE_TRANSFER_BACKBONES"]
+
+MLE_FN_PATTERN = re.compile(r"^(?P<name>\w+?)_multi_layer_encoder$")
+
+STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones")
+
+for mle_fn in dir(enc):
+    match = MLE_FN_PATTERN.match(mle_fn)
+    if not match:
+        continue
+
+    STYLE_TRANSFER_BACKBONES(
+        fn=lambda: (getattr(enc, mle_fn)(), None),
+        name=match.group("name"),
+        namespace="vision/style_transfer",
+        package="pystiche",
+    )
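Patch 27 registers every pystiche *_multi_layer_encoder function under its prefix as a backbone. One caveat worth noting: the registering lambda closes over the loop variable mle_fn, so under Python's usual late-binding rules every registered name would construct the encoder of the last matching function; binding it as a default argument (lambda mle_fn=mle_fn: ...) is the conventional fix. A usage sketch of the registry as written, assuming the patch is applied (constructing the encoder downloads its weights on first use):

    from flash.vision.style_transfer.backbone import STYLE_TRANSFER_BACKBONES

    # Each entry is a zero-argument constructor returning a
    # (multi_layer_encoder, metadata) pair; "vgg16" mirrors pystiche's
    # enc.vgg16_multi_layer_encoder and is the default the task uses.
    multi_layer_encoder, _ = STYLE_TRANSFER_BACKBONES.get("vgg16")()
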
@@ -from typing import Any, Dict, Mapping, NoReturn, Optional, Type, Union +from typing import Any, cast, Dict, Mapping, NoReturn, Optional, Sequence, Type, Union +import pystiche.demo import torch from pystiche import enc, loss, ops from pystiche.image import read_image from torch import nn -from torch.nn.functional import interpolate from torch.optim.lr_scheduler import _LRScheduler from flash.core import Task +from flash.core.registry import FlashRegistry from flash.data.data_source import DefaultDataKeys from flash.data.process import Serializer +from flash.vision.style_transfer.backbone import STYLE_TRANSFER_BACKBONES from ._utils import raise_not_supported __all__ = ["StyleTransfer"] -class Interpolate(nn.Module): - - def __init__(self, scale_factor: float = 1.0, mode: str = "nearest") -> None: - super().__init__() - self.scale_factor = scale_factor - self.mode = mode - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return interpolate(input, scale_factor=self.scale_factor, mode=self.mode) - - def extra_repr(self) -> str: - extras = [f"scale_factor={self.scale_factor}"] - if self.mode != "nearest": - extras.append(f"mode={self.mode}") - return ", ".join(extras) - - -class Conv(nn.Module): - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - *, - stride: int = 1, - upsample: bool = False, - norm: bool = True, - activation: bool = True, - ): - super().__init__() - self.upsample = Interpolate(scale_factor=stride) if upsample else None - self.pad = nn.ReflectionPad2d(kernel_size // 2) - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1 if upsample else stride) - self.norm = nn.InstanceNorm2d(out_channels, affine=True) if norm else None - self.activation = nn.ReLU() if activation else None - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.upsample: - input = self.upsample(input) - - output = self.conv(self.pad(input)) - - if self.norm: - output = self.norm(output) - if self.activation: - output = self.activation(output) - - return output - - -class Residual(nn.Module): - - def __init__(self, channels: int) -> None: - super().__init__() - self.conv1 = Conv(channels, channels, kernel_size=3) - self.conv2 = Conv(channels, channels, kernel_size=3, activation=False) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - output = self.conv2(self.conv1(input)) - return output + input - - -class FloatToUint8Range(nn.Module): - - def forward(self, input): - return input * 255.0 - - -class Uint8ToFloatRange(nn.Module): - - def forward(self, input): - return input / 255.0 - - -class Transformer(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.encoder = nn.Sequential( - Conv(3, 32, kernel_size=9), - Conv(32, 64, kernel_size=3, stride=2), - Conv(64, 128, kernel_size=3, stride=2), - Residual(128), - Residual(128), - Residual(128), - Residual(128), - Residual(128), - ) - self.decoder = nn.Sequential( - Conv(128, 64, kernel_size=3, stride=2, upsample=True), - Conv(64, 32, kernel_size=3, stride=2, upsample=True), - Conv(32, 3, kernel_size=9, norm=False, activation=False), - ) - - self.preprocessor = FloatToUint8Range() - self.postprocessor = Uint8ToFloatRange() - - def forward(self, input): - input = self.preprocessor(input) - output = self.decoder(self.encoder(input)) - return self.postprocessor(output) - - class StyleTransfer(Task): + backbones: FlashRegistry = STYLE_TRANSFER_BACKBONES def __init__( self, - style_image: Union[str, torch.Tensor], + style_image: Optional[Union[str, 
torch.Tensor]] = None, model: Optional[nn.Module] = None, - multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None, - content_loss: Optional[Union[ops.ComparisonOperator, ops.OperatorContainer]] = None, - style_loss: Optional[Union[ops.ComparisonOperator, ops.OperatorContainer]] = None, + backbone: str = "vgg16", + content_layer: str = "relu2_2", + content_weight: float = 1e5, + style_layers: Sequence[str] = ("relu1_2", "relu2_2", "relu3_3", "relu4_3"), + style_weight: float = 1e10, optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, @@ -138,23 +39,22 @@ def __init__( ): self.save_hyperparameters(ignore="style_image") - if isinstance(style_image, str): + if style_image is None: + style_image = self.default_style_image() + elif isinstance(style_image, str): style_image = read_image(style_image) if model is None: - # TODO: import this from pystiche - model = Transformer() - - if multi_layer_encoder is None: - multi_layer_encoder = self.default_multi_layer_encoder() - - if content_loss is None: - content_loss = self.default_content_loss(multi_layer_encoder) - - if style_loss is None: - style_loss = self.default_style_loss(multi_layer_encoder) - - perceptual_loss = loss.PerceptualLoss(content_loss, style_loss) + model = pystiche.demo.transformer() + + perceptual_loss = self._get_perceptual_loss( + backbone=backbone, + content_layer=content_layer, + content_weight=content_weight, + style_layers=style_layers, + style_weight=style_weight, + ) + perceptual_loss.set_style_image(style_image) super().__init__( model=model, @@ -168,57 +68,49 @@ def __init__( ) self.perceptual_loss = perceptual_loss - self.perceptual_loss.set_style_image(style_image) - - def default_multi_layer_encoder(self) -> enc.MultiLayerEncoder: - return enc.vgg16_multi_layer_encoder() - def default_content_loss( - self, multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None - ) -> ops.FeatureReconstructionOperator: - if multi_layer_encoder is None: - multi_layer_encoder = self.default_multi_layer_encoder() - content_layer = "relu2_2" - content_encoder = multi_layer_encoder.extract_encoder(content_layer) - content_weight = 1e5 - return ops.FeatureReconstructionOperator(content_encoder, score_weight=content_weight) - - def default_style_loss( - self, multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None - ) -> ops.MultiLayerEncodingOperator: + def default_style_image(self) -> torch.Tensor: + return pystiche.demo.images()["paint"].read(size=256) + @staticmethod + def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> ops.EncodingComparisonOperator: + # TODO class GramOperator(ops.GramOperator): - def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor: repr = super().enc_to_repr(enc) num_channels = repr.size()[1] return repr / num_channels - if multi_layer_encoder is None: - multi_layer_encoder = self.default_multi_layer_encoder() + return GramOperator(encoder, score_weight=score_weight) - style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3") - style_weight = 1e10 - return ops.MultiLayerEncodingOperator( - multi_layer_encoder, + def _get_perceptual_loss( + self, + *, + backbone: str, + content_layer: str, + content_weight: float, + style_layers: Sequence[str], + style_weight: float, + ) -> loss.PerceptualLoss: + mle, _ = cast(enc.MultiLayerEncoder, self.backbones.get(backbone)()) + content_loss = ops.FeatureReconstructionOperator( + 
mle.extract_encoder(content_layer), score_weight=content_weight + ) + style_loss = ops.MultiLayerEncodingOperator( + mle, style_layers, - lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight), + lambda encoder, layer_weight: self._modified_gram_loss(encoder, score_weight=layer_weight), layer_weights="sum", score_weight=style_weight, ) + return loss.PerceptualLoss(content_loss, style_loss) - def step(self, batch: Any, batch_idx: int) -> Any: - input_image = batch + def training_step(self, batch: Any, batch_idx: int) -> Any: + input_image = batch[DefaultDataKeys.INPUT] self.perceptual_loss.set_content_image(input_image) - output_image = self(batch) - loss = self.perceptual_loss(output_image).total() - - logs = dict(perceptual_loss=loss) - return dict(loss=loss, logs=logs) - - def training_step(self, batch: Any, batch_idx: int) -> Any: - return super().training_step(batch[DefaultDataKeys.INPUT], batch_idx) + output_image = self(input_image) + return self.perceptual_loss(output_image).total() def validation_step(self, batch: Any, batch_idx: int) -> NoReturn: raise_not_supported("validation")
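With the refactor complete, the objective that the new training_step minimizes can be exercised in isolation. A minimal sketch with a hypothetical batch, assuming pystiche is installed; it uses the same set_content_image and .total() calls as the code above:

    import torch

    from flash.vision.style_transfer import StyleTransfer

    task = StyleTransfer()  # default demo style image and vgg16 backbone
    batch = torch.rand(4, 3, 256, 256)  # stand-in for a batch of content images
    task.perceptual_loss.set_content_image(batch)
    output = task(batch)  # stylized images produced by the transformer
    loss = task.perceptual_loss(output).total()  # content term + style term
    loss.backward()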
From de6299636b5ac35a0da39d42c313d527dc25607e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 14 May 2021 20:51:12 +0200 Subject: [PATCH 29/60] add explanation for modified gram operator --- flash/vision/style_transfer/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py index a27bb8c94b..c92077635b 100644 --- a/flash/vision/style_transfer/model.py +++ b/flash/vision/style_transfer/model.py @@ -74,7 +74,9 @@ def default_style_image(self) -> torch.Tensor: @staticmethod def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> ops.EncodingComparisonOperator: - # TODO + # The official PyTorch examples as well as the reference implementation of the original author contain an + # oversight: they normalize the representation twice by the number of channels. To be compatible with them, we + # do the same here. class GramOperator(ops.GramOperator): def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor: repr = super().enc_to_repr(enc) num_channels = repr.size()[1] From 02a0c29850948ea932b9ca558983372cc0857c55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 May 2021 18:54:58 +0000 Subject: [PATCH 30/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/vision/style_transfer/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash/vision/style_transfer/model.py b/flash/vision/style_transfer/model.py index c92077635b..012d6f6398 100644 --- a/flash/vision/style_transfer/model.py +++ b/flash/vision/style_transfer/model.py @@ -78,6 +78,7 @@ def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> ops.Enc # oversight: they normalize the representation twice by the number of channels. To be compatible with them, we # do the same here. class GramOperator(ops.GramOperator): + def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor: repr = super().enc_to_repr(enc) num_channels = repr.size()[1] From 21d0817e3debc62774ac57b1d6c2759c4a2f28bf Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sun, 16 May 2021 14:44:35 +0200 Subject: [PATCH 31/60] streamline default transforms --- flash/core/data/data_pipeline.py | 4 +-- flash/image/style_transfer/data.py | 52 +++++++++++++----------------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index d768050c5d..64dff4d948 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -180,7 +180,7 @@ def _resolve_function_hierarchy( if object_type is None: object_type = Preprocess - prefixes = [''] + prefixes = [None] if stage in (RunningStage.TRAINING, RunningStage.TUNING): prefixes += ['train', 'fit'] elif stage == RunningStage.VALIDATING: @@ -192,7 +192,7 @@ for prefix in prefixes: if cls._is_overriden(function_name, process_obj, object_type, prefix=prefix): - return f'{prefix}_{function_name}' + return function_name if prefix is None else f'{prefix}_{function_name}' return function_name diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 5eaffdd7e1..7d7578f57a 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -19,8 +19,11 @@ def _apply_to_input(default_transforms_fn) -> Callable[..., Dict[str, ApplyToKeys]]: @functools.wraps(default_transforms_fn) - def wrapper(*args: Any, **kwargs: Any) -> Dict[str, ApplyToKeys]: + def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: default_transforms = default_transforms_fn(*args, **kwargs) + if not default_transforms: + return default_transforms + return {hook: ApplyToKeys(DefaultDataKeys.INPUT, transform) for hook, transform in default_transforms.items()} return wrapper @@ -51,35 +54,24 @@ def __init__( image_size=image_size, ) - @property - @_apply_to_input - def default_train_transforms(self) -> Dict[str, Callable]: - return dict( - to_tensor_transform=torchvision.transforms.ToTensor(), - per_sample_transform_on_device=nn.Sequential( - transforms.Resize(min(self.image_size)), - transforms.CenterCrop(self.image_size), - ), - ) - - @property - def default_val_transforms(self) -> None: - # Style transfer doesn't support a validation phase, so we return nothing here - return None - - @property - def default_test_transforms(self) -> None: - # Style transfer doesn't support a test phase, so we return nothing here - return None - - @property @_apply_to_input - def default_predict_transforms(self) -> Dict[str, Callable]: - - return dict( - to_tensor_transform=torchvision.transforms.ToTensor(), - per_sample_transform_on_device=transforms.Resize(min(self.image_size)), - ) + def default_transforms(self) -> Optional[Dict[str, Callable]]: + if self.training: + return dict( + to_tensor_transform=torchvision.transforms.ToTensor(), + per_sample_transform_on_device=nn.Sequential( + transforms.Resize(min(self.image_size)), + transforms.CenterCrop(self.image_size), + ), + ) + elif self.predicting: + return dict( + to_tensor_transform=torchvision.transforms.ToTensor(), + per_sample_transform_on_device=transforms.Resize(min(self.image_size)), + ) + else: + # Style transfer doesn't support a validation or test phase, so we return nothing here + return None class StyleTransferData(ImageClassificationData): @@ -103,7
+95,7 @@ def from_folders( preprocess = preprocess or cls.preprocess_cls(train_transform, predict_transform) return cls.from_data_source( - DefaultDataSources.PATHS, + DefaultDataSources.FOLDERS, train_data=train_folder, predict_data=predict_folder, preprocess=preprocess, From f868489dfc9043cd09c48f091ab7fcc4ed790863 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sun, 16 May 2021 15:26:55 +0200 Subject: [PATCH 32/60] add disabled test for finetuning example --- tests/examples/test_scripts.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 7564a98941..851cbc6554 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -94,6 +94,11 @@ def run_test(filepath): "translation.py", marks=pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed") ), + # pytest.param( + # "finetuning", + # "style_transfer.py", + # marks=pytest.mark.skipif(not _PYSTICHE_AVAILABLE, reason="pystiche is not installed") + # ), # TODO: takes too long pytest.param( "predict", "image_classification.py", From 39baffd62815890c4c61a9991d72d16302da3619 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sun, 16 May 2021 15:37:18 +0200 Subject: [PATCH 33/60] add documentation skeleton --- docs/source/index.rst | 1 + docs/source/reference/style_transfer.rst | 35 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 docs/source/reference/style_transfer.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 49de343ea3..abdd2b6f8d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,6 +30,7 @@ Lightning Flash reference/object_detection reference/video_classification reference/semantic_segmentation + reference/style_transfer .. toctree:: :maxdepth: 1 diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst new file mode 100644 index 0000000000..0d60c1b3b1 --- /dev/null +++ b/docs/source/reference/style_transfer.rst @@ -0,0 +1,35 @@ +############## +Style Transfer +############## + +******** +The task +******** + +TODO + +------ + +*** +Fit +*** + +TODO + +------ + +************* +API reference +************* + +StyleTransfer +------------- + +.. autoclass:: flash.image.StyleTransfer + :members: + :exclude-members: forward + +StyleTransferData +----------------- + +.. autoclass:: flash.image.StyleTransferData From 7660f5154649c3ca4338be8c66923dbd1398834f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sun, 16 May 2021 15:42:17 +0200 Subject: [PATCH 34/60] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f383676bf..588523b910 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState ([#229](https://github.com/PyTorchLightning/lightning-flash/pull/229)) - Added Semantic Segmentation task ([#239](https://github.com/PyTorchLightning/lightning-flash/pull/239) [#287](https://github.com/PyTorchLightning/lightning-flash/pull/287) [#290](https://github.com/PyTorchLightning/lightning-flash/pull/290)) - Added Object detection prediction example ([#283](https://github.com/PyTorchLightning/lightning-flash/pull/283)) +- Added Style Transfer task and accompanying finetuning and prediction examples ([#262](https://github.com/PyTorchLightning/lightning-flash/pull/262)) ### Changed From 46ff9d0c263c1533dd859838adf35e3f54e03185 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 11:28:43 +0100 Subject: [PATCH 35/60] update --- .github/workflows/ci-testing.yml | 4 ++++ flash/core/data/process.py | 2 +- flash/core/data/transforms.py | 5 ++++ flash/core/utilities/imports.py | 2 ++ flash/image/style_transfer/data.py | 23 +++++++++++-------- flash/image/style_transfer/model.py | 19 +++++++++++---- .../style_transfer/{_utils.py => utils.py} | 0 flash_examples/finetuning/style_transfer.py | 10 ++++---- .../predict/semantic_segmentation.py | 4 +++- flash_examples/predict/style_transfer.py | 21 +++++++++++++++++ requirements/datatype_image.txt | 2 -- .../datatype_image_style_transfer.txt | 3 +++ setup.py | 5 +++- tests/examples/test_scripts.py | 13 +++++++---- 14 files changed, 84 insertions(+), 29 deletions(-) rename flash/image/style_transfer/{_utils.py => utils.py} (100%) create mode 100644 flash_examples/predict/style_transfer.py create mode 100644 requirements/datatype_image_style_transfer.txt diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 70e4465cbe..179a00a051 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -37,6 +37,10 @@ jobs: python-version: 3.8 requires: 'latest' topic: 'text' + - os: ubuntu-20.04 + python-version: 3.8 + requires: 'latest' + topic: 'image_style_transfer' # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 35 diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 30783510b6..7f3c0a2043 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -24,7 +24,7 @@ from flash.core.data.batch import default_uncollate from flash.core.data.callback import FlashCallback -from flash.core.data.data_source import DataSource +from flash.core.data.data_source import DataSource, DefaultDataKeys from flash.core.data.properties import Properties from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext diff --git a/flash/core/data/transforms.py b/flash/core/data/transforms.py index d80ec15e69..cb48f495e7 100644 --- a/flash/core/data/transforms.py +++ b/flash/core/data/transforms.py @@ -53,6 +53,11 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: return result return x + def __repr__(self): + keys = self.keys[0] if len(self.keys) == 1 else self.keys + transform = [c for c in self.children()] + return f"{self.__class__.__name__}(keys={keys}, transform={transform})" + class KorniaParallelTransforms(nn.Sequential): """The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 522482ae18..e5999ac778 100644 --- a/flash/core/utilities/imports.py +++ 
b/flash/core/utilities/imports.py @@ -78,7 +78,9 @@ def _compare_version(package: str, op, version) -> bool: if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") + _PYSTICHE_GREATER_EQUAL_0_7_2 = _compare_version("pystiche", operator.ge, "0.7.2") +_IMAGE_STLYE_TRANSFER = _PYSTICHE_AVAILABLE _TEXT_AVAILABLE = _TRANSFORMERS_AVAILABLE _TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE _VIDEO_AVAILABLE = _PYTORCHVIDEO_AVAILABLE diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 7d7578f57a..60cabe7c0b 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -1,6 +1,6 @@ import functools import pathlib -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import torchvision from torch import nn @@ -10,13 +10,13 @@ from flash.core.data.process import Preprocess from flash.core.data.transforms import ApplyToKeys from flash.image.classification import ImageClassificationData, ImageClassificationPreprocess - -from ._utils import raise_not_supported +from flash.image.style_transfer.utils import raise_not_supported __all__ = ["StyleTransferPreprocess", "StyleTransferData"] -def _apply_to_input(default_transforms_fn) -> Callable[..., Dict[str, ApplyToKeys]]: +def _apply_to_input(default_transforms_fn, keys: Union[Sequence[DefaultDataKeys], + DefaultDataKeys]) -> Callable[..., Dict[str, ApplyToKeys]]: @functools.wraps(default_transforms_fn) def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: @@ -24,7 +24,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: if not default_transforms: return default_transforms - return {hook: ApplyToKeys(DefaultDataKeys.INPUT, transform) for hook, transform in default_transforms.items()} + return {hook: ApplyToKeys(keys, transform) for hook, transform in default_transforms.items()} return wrapper @@ -54,20 +54,20 @@ def __init__( image_size=image_size, ) - @_apply_to_input + @functools.partial(_apply_to_input, keys=DefaultDataKeys.INPUT) def default_transforms(self) -> Optional[Dict[str, Callable]]: if self.training: return dict( to_tensor_transform=torchvision.transforms.ToTensor(), per_sample_transform_on_device=nn.Sequential( - transforms.Resize(min(self.image_size)), + transforms.Resize(self.image_size), transforms.CenterCrop(self.image_size), ), ) elif self.predicting: return dict( + pre_tensor_transform=transforms.Resize(self.image_size), to_tensor_transform=torchvision.transforms.ToTensor(), - per_sample_transform_on_device=transforms.Resize(min(self.image_size)), ) else: # Style transfer doesn't support a validation or test phase, so we return nothing here @@ -87,12 +87,17 @@ def from_folders( preprocess: Optional[Preprocess] = None, **kwargs: Any, ) -> "StyleTransferData": + if any(param in kwargs for param in ("val_folder", "val_transform")): raise_not_supported("validation") + if any(param in kwargs for param in ("test_folder", "test_transform")): raise_not_supported("test") - preprocess = preprocess or cls.preprocess_cls(train_transform, predict_transform) + preprocess = preprocess or cls.preprocess_cls( + train_transform=train_transform, + predict_transform=predict_transform, + ) return cls.from_data_source( DefaultDataSources.FOLDERS, diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 7ba3957557..9720488717 100644 --- a/flash/image/style_transfer/model.py 
+++ b/flash/image/style_transfer/model.py @@ -1,9 +1,6 @@ from typing import Any, cast, Dict, Mapping, NoReturn, Optional, Sequence, Type, Union -import pystiche.demo import torch -from pystiche import enc, loss, ops -from pystiche.image import read_image from torch import nn from torch.optim.lr_scheduler import _LRScheduler @@ -11,9 +8,15 @@ from flash.core.data.process import Serializer from flash.core.model import Task from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _IMAGE_STLYE_TRANSFER from flash.image.style_transfer import STYLE_TRANSFER_BACKBONES -from ._utils import raise_not_supported +if _IMAGE_STLYE_TRANSFER: + import pystiche.demo + from pystiche import enc, loss, ops + from pystiche.image import read_image + +from flash.image.style_transfer.utils import raise_not_supported __all__ = ["StyleTransfer"] @@ -37,6 +40,9 @@ def __init__( learning_rate: float = 1e-3, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, ): + if not _IMAGE_STLYE_TRANSFER: + raise ModuleNotFoundError("Please, pip install -e '.[image_style_transfer]'") + self.save_hyperparameters(ignore="style_image") if style_image is None: @@ -111,7 +117,6 @@ def _get_perceptual_loss( def training_step(self, batch: Any, batch_idx: int) -> Any: input_image = batch[DefaultDataKeys.INPUT] self.perceptual_loss.set_content_image(input_image) - output_image = self(input_image) return self.perceptual_loss(output_image).total() @@ -120,3 +125,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> NoReturn: def test_step(self, batch: Any, batch_idx: int) -> NoReturn: raise_not_supported("test") + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Any: + input_image = batch[DefaultDataKeys.INPUT] + return self(input_image) diff --git a/flash/image/style_transfer/_utils.py b/flash/image/style_transfer/utils.py similarity index 100% rename from flash/image/style_transfer/_utils.py rename to flash/image/style_transfer/utils.py diff --git a/flash_examples/finetuning/style_transfer.py b/flash_examples/finetuning/style_transfer.py index 57ebf9ad9c..31b49ad6b2 100644 --- a/flash_examples/finetuning/style_transfer.py +++ b/flash_examples/finetuning/style_transfer.py @@ -6,15 +6,15 @@ if _PYSTICHE_AVAILABLE: import pystiche.demo + + from flash.image.style_transfer import StyleTransfer, StyleTransferData else: print("Please, run `pip install pystiche`") - sys.exit(0) - -from flash.image.style_transfer import StyleTransfer, StyleTransferData + sys.exit(1) -download_data("http://images.cocodataset.org/zips/train2014.zip", "data") +download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") -data_module = StyleTransferData.from_folders(train_folder="data", batch_size=4) +data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4) style_image = pystiche.demo.images()["paint"].read(size=256) diff --git a/flash_examples/predict/semantic_segmentation.py b/flash_examples/predict/semantic_segmentation.py index 41bd89654a..3938ffe17a 100644 --- a/flash_examples/predict/semantic_segmentation.py +++ b/flash_examples/predict/semantic_segmentation.py @@ -24,7 +24,9 @@ ) # 2. Load the model from a checkpoint -model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt") +model = SemanticSegmentation.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt" +) model.serializer = SegmentationLabels(visualize=True) # 3. 
Predict what's on a few images and visualize! diff --git a/flash_examples/predict/style_transfer.py b/flash_examples/predict/style_transfer.py new file mode 100644 index 0000000000..3a00e5d63f --- /dev/null +++ b/flash_examples/predict/style_transfer.py @@ -0,0 +1,21 @@ +import sys + +import flash +from flash.core.data.utils import download_data +from flash.core.utilities.imports import _PYSTICHE_AVAILABLE +from flash.image.style_transfer import StyleTransfer, StyleTransferData + +if not _PYSTICHE_AVAILABLE: + print("Please, run `pip install pystiche`") + sys.exit(1) + +from flash.image.style_transfer import StyleTransfer, StyleTransferData + +download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") + +model = StyleTransfer.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/style_transfer_model.pt") + +datamodule = StyleTransferData.from_folders(predict_folder="data/coco128/images", batch_size=4) + +trainer = flash.Trainer(max_epochs=2) +trainer.predict(model, datamodule=datamodule) diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index b7258131bf..c3e9f4ec54 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -5,5 +5,3 @@ Pillow>=7.2 kornia>=0.5.1 matplotlib pycocotools>=2.0.2 ; python_version >= "3.7" -git+https://github.com/pystiche/pystiche@releases/v0.7.2 -# pystiche>=0.7.2 diff --git a/requirements/datatype_image_style_transfer.txt b/requirements/datatype_image_style_transfer.txt new file mode 100644 index 0000000000..dd830e0f0c --- /dev/null +++ b/requirements/datatype_image_style_transfer.txt @@ -0,0 +1,3 @@ +torchvision +kornia>=0.5.1 +pystiche diff --git a/setup.py b/setup.py index df687a01f5..a7e53d8826 100644 --- a/setup.py +++ b/setup.py @@ -32,11 +32,14 @@ def _load_py_module(fname, pkg="flash"): "text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_text.txt"), "tabular": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_tabular.txt"), "image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_image.txt"), + "image_style_transfer": setup_tools._load_requirements( + path_dir=_PATH_REQUIRE, file_name="datatype_image_style_transfer.txt" + ), "video": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_video.txt"), } # remove possible duplicate. 
-extras["vision"] = list(set(extras["image"] + extras["video"])) +extras["vision"] = list(set(extras["image"] + extras["video"] + extras["image_style_transfer"])) extras["dev"] = list(set(extras["vision"] + extras["tabular"] + extras["text"] + extras["image"])) extras["dev-test"] = list(set(extras["test"] + extras["dev"])) extras["all"] = list(set(extras["dev"] + extras["docs"])) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 851cbc6554..24d93621d9 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -22,6 +22,8 @@ from flash.core.utilities.imports import ( _IMAGE_AVAILABLE, + _IMAGE_STLYE_TRANSFER, + _PYSTICHE_GREATER_EQUAL_0_7_2, _TABULAR_AVAILABLE, _TEXT_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_9, @@ -29,6 +31,7 @@ ) _IMAGE_AVAILABLE = _IMAGE_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_9 +_IMAGE_STLYE_TRANSFER = _IMAGE_STLYE_TRANSFER and _PYSTICHE_GREATER_EQUAL_0_7_2 root = Path(__file__).parent.parent.parent @@ -94,11 +97,11 @@ def run_test(filepath): "translation.py", marks=pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed") ), - # pytest.param( - # "finetuning", - # "style_transfer.py", - # marks=pytest.mark.skipif(not _PYSTICHE_AVAILABLE, reason="pystiche is not installed") - # ), # TODO: takes too long + pytest.param( + "finetuning", + "style_transfer.py", + marks=pytest.mark.skipif(not _IMAGE_STLYE_TRANSFER, reason="pystiche is not installed") + ), pytest.param( "predict", "image_classification.py",
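The @functools.partial(_apply_to_input, keys=...) construction used in the new data.py above is worth unpacking: functools.partial freezes the keys argument, leaving a one-argument callable that works as a plain decorator. A stripped-down, runnable sketch with toy names (not the flash implementation itself):

    import functools

    def apply_to_input(fn, keys):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            transforms = fn(*args, **kwargs)
            if not transforms:
                return transforms
            # pair every transform with the keys it should be applied to
            return {hook: (keys, transform) for hook, transform in transforms.items()}
        return wrapper

    @functools.partial(apply_to_input, keys="input")
    def default_transforms():
        return {"to_tensor_transform": "ToTensor()"}

    print(default_transforms())  # {'to_tensor_transform': ('input', 'ToTensor()')}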
From e230a2c2b3248fc137e47d9deb79eab5350f61d6 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 11:36:30 +0100 Subject: [PATCH 36/60] update --- flash/image/style_transfer/backbone.py | 33 ++++++++++++----------- flash/image/style_transfer/model.py | 36 ++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/flash/image/style_transfer/backbone.py b/flash/image/style_transfer/backbone.py index dfcd883af3..84bef9c18d 100644 --- a/flash/image/style_transfer/backbone.py +++ b/flash/image/style_transfer/backbone.py @@ -1,23 +1,26 @@ import re -from pystiche import enc - from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _PYSTICHE_AVAILABLE + +STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones") __all__ = ["STYLE_TRANSFER_BACKBONES"] -MLE_FN_PATTERN = re.compile(r"^(?P<name>\w+?)_multi_layer_encoder$") +if _PYSTICHE_AVAILABLE: -STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones") + from pystiche import enc + + MLE_FN_PATTERN = re.compile(r"^(?P<name>\w+?)_multi_layer_encoder$") + + for mle_fn in dir(enc): + match = MLE_FN_PATTERN.match(mle_fn) + if not match: + continue -for mle_fn in dir(enc): - match = MLE_FN_PATTERN.match(mle_fn) - if not match: - continue - - STYLE_TRANSFER_BACKBONES( - fn=lambda: (getattr(enc, mle_fn)(), None), - name=match.group("name"), - namespace="image/style_transfer", - package="pystiche", - ) + STYLE_TRANSFER_BACKBONES( + fn=lambda: (getattr(enc, mle_fn)(), None), + name=match.group("name"), + namespace="image/style_transfer", + package="pystiche", + ) diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 9720488717..fc6d02dddb 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -22,6 +22,41 @@ class StyleTransfer(Task): + """Task that transfers the style from an image onto another. + + + Use a built in backbone + + Example:: + + from flash.image import ImageClassifier + + classifier = ImageClassifier(backbone='resnet18') + + Or your own backbone (num_features is the number of features produced by your backbone) + + Example:: + + from flash.image import ImageClassifier + from torch import nn + + # use any backbone + some_backbone = nn.Conv2D(...) + num_out_features = 1024 + classifier = ImageClassifier(backbone=(some_backbone, num_out_features)) + + + Args: + num_classes: Number of classes to classify. + backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"resnet18"``. + pretrained: Use a pretrained backbone, defaults to ``True``. + loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. + optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. + metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.Accuracy`. + learning_rate: Learning rate to use for training, defaults to ``1e-3``. + multi_label: Whether the targets are multi-label or not. + serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + """ + backbones: FlashRegistry = STYLE_TRANSFER_BACKBONES def __init__( @@ -40,6 +75,7 @@ def __init__( learning_rate: float = 1e-3, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, ): + if not _IMAGE_STLYE_TRANSFER: raise ModuleNotFoundError("Please, pip install -e '.[image_style_transfer]'")
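With the guard in place, constructing the task looks roughly like this (a hypothetical invocation based on the signature above; the style image path is a placeholder):

    from flash.image.style_transfer import StyleTransfer

    model = StyleTransfer(
        style_image="path/to/style.jpg",  # placeholder; a tensor also works
        backbone="vgg16",                 # resolved through STYLE_TRANSFER_BACKBONES
        content_layer="relu2_2",
        content_weight=1e5,
        style_weight=1e10,
    )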
From 40fd08b72e546dafe2f255b3fd05b7384270a4f0 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 12:22:53 +0100 Subject: [PATCH 37/60] update --- flash/image/style_transfer/backbone.py | 12 +++--- flash/image/style_transfer/model.py | 43 ++++++----------- tests/{vision => image}/__init__.py | 0 .../classification/__init__.py | 0 .../classification/test_data.py | 0 .../test_data_model_integration.py | 0 .../classification/test_model.py | 0 tests/{vision => image}/detection/__init__.py | 0 .../{vision => image}/detection/test_data.py | 0 .../detection/test_data_model_integration.py | 0 .../{vision => image}/detection/test_model.py | 0 .../segmentation/__init__.py | 0 .../segmentation/test_data.py | 0 .../segmentation/test_model.py | 0 .../segmentation/test_serialization.py | 0 tests/image/style_transfer/test_model.py | 22 ++++++++++ tests/{vision => image}/test_backbones.py | 0 17 files changed, 44 insertions(+), 33 deletions(-) rename tests/{vision => image}/__init__.py (100%) rename tests/{vision => image}/classification/__init__.py (100%) rename tests/{vision => image}/classification/test_data.py (100%) rename tests/{vision => image}/classification/test_data_model_integration.py (100%) rename tests/{vision => image}/classification/test_model.py (100%) rename tests/{vision => image}/detection/__init__.py (100%) rename tests/{vision => image}/detection/test_data.py (100%) rename tests/{vision => image}/detection/test_data_model_integration.py (100%) rename tests/{vision => image}/detection/test_model.py (100%) rename tests/{vision => image}/segmentation/__init__.py (100%) rename tests/{vision => image}/segmentation/test_data.py (100%) rename tests/{vision => image}/segmentation/test_model.py (100%) rename tests/{vision => image}/segmentation/test_serialization.py (100%) create mode 100644 tests/image/style_transfer/test_model.py rename tests/{vision => image}/test_backbones.py (100%) diff --git a/flash/image/style_transfer/backbone.py b/flash/image/style_transfer/backbone.py index 84bef9c18d..95d5e40024 100644 --- a/flash/image/style_transfer/backbone.py +++ b/flash/image/style_transfer/backbone.py @@ -1,5 +1,3 @@ -import re - from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _PYSTICHE_AVAILABLE @@ -11,16 +9,16 @@ from pystiche import enc - MLE_FN_PATTERN = re.compile(r"^(?P<name>\w+?)_multi_layer_encoder$") - for mle_fn in dir(enc): - match = MLE_FN_PATTERN.match(mle_fn) - if not match: + + if not "multi_layer_encoder" in mle_fn: continue + name = mle_fn.split("_")[0] + STYLE_TRANSFER_BACKBONES( fn=lambda: (getattr(enc, mle_fn)(), None), - name=match.group("name"), + name=mle_fn.split("_")[0], namespace="image/style_transfer", package="pystiche", ) diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index fc6d02dddb..c74fc602cd 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -1,4 +1,4 @@ -from typing import Any, cast, Dict, Mapping, NoReturn, Optional, Sequence, Type, Union +from typing import Any, cast, Dict, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union import torch from torch import nn @@ -24,36 +24,24 @@ class StyleTransfer(Task): """Task that transfers the style from an image onto another. - Use a built in backbone - - Example:: - - from flash.image import ImageClassifier - - classifier = ImageClassifier(backbone='resnet18') - - Or your own backbone (num_features is the number of features produced by your backbone) - Example:: + from flash.image.style_transfer import StyleTransfer - from flash.image import ImageClassifier - from torch import nn - - # use any backbone - some_backbone = nn.Conv2D(...) - num_out_features = 1024 - classifier = ImageClassifier(backbone=(some_backbone, num_out_features)) + model = StyleTransfer(style_image) Args: - num_classes: Number of classes to classify. - backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"resnet18"``. - pretrained: Use a pretrained backbone, defaults to ``True``. - loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. - optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. - metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.Accuracy`. + style_image: Image or path to an image to derive the style from. + model: The model used by the style transfer task. + backbone: A string or model to use to compute the style loss from. + content_layer: Which layer from the backbone to extract the content loss from. + content_weight: The weight associated with the content loss. A lower value will lose content over style. + style_layers: Layers from the backbone to derive the style loss from. + optimizer: Optimizer to use for training the model. + optimizer_kwargs: Optimizer keywords arguments. + scheduler: Scheduler to use for training the model. + scheduler_kwargs: Scheduler keywords arguments. learning_rate: Learning rate to use for training, defaults to ``1e-3``. - multi_label: Whether the targets are multi-label or not. serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
""" @@ -66,7 +54,7 @@ def __init__( backbone: str = "vgg16", content_layer: str = "relu2_2", content_weight: float = 1e5, - style_layers: Sequence[str] = ("relu1_2", "relu2_2", "relu3_3", "relu4_3"), + style_layers: Union[Sequence[str], str] = ("relu1_2", "relu2_2", "relu3_3", "relu4_3"), style_weight: float = 1e10, optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -89,6 +77,9 @@ def __init__( if model is None: model = pystiche.demo.transformer() + if not isinstance(style_layers, (List, Tuple)): + style_layers = (style_layers, ) + perceptual_loss = self._get_perceptual_loss( backbone=backbone, content_layer=content_layer, diff --git a/tests/vision/__init__.py b/tests/image/__init__.py similarity index 100% rename from tests/vision/__init__.py rename to tests/image/__init__.py diff --git a/tests/vision/classification/__init__.py b/tests/image/classification/__init__.py similarity index 100% rename from tests/vision/classification/__init__.py rename to tests/image/classification/__init__.py diff --git a/tests/vision/classification/test_data.py b/tests/image/classification/test_data.py similarity index 100% rename from tests/vision/classification/test_data.py rename to tests/image/classification/test_data.py diff --git a/tests/vision/classification/test_data_model_integration.py b/tests/image/classification/test_data_model_integration.py similarity index 100% rename from tests/vision/classification/test_data_model_integration.py rename to tests/image/classification/test_data_model_integration.py diff --git a/tests/vision/classification/test_model.py b/tests/image/classification/test_model.py similarity index 100% rename from tests/vision/classification/test_model.py rename to tests/image/classification/test_model.py diff --git a/tests/vision/detection/__init__.py b/tests/image/detection/__init__.py similarity index 100% rename from tests/vision/detection/__init__.py rename to tests/image/detection/__init__.py diff --git a/tests/vision/detection/test_data.py b/tests/image/detection/test_data.py similarity index 100% rename from tests/vision/detection/test_data.py rename to tests/image/detection/test_data.py diff --git a/tests/vision/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py similarity index 100% rename from tests/vision/detection/test_data_model_integration.py rename to tests/image/detection/test_data_model_integration.py diff --git a/tests/vision/detection/test_model.py b/tests/image/detection/test_model.py similarity index 100% rename from tests/vision/detection/test_model.py rename to tests/image/detection/test_model.py diff --git a/tests/vision/segmentation/__init__.py b/tests/image/segmentation/__init__.py similarity index 100% rename from tests/vision/segmentation/__init__.py rename to tests/image/segmentation/__init__.py diff --git a/tests/vision/segmentation/test_data.py b/tests/image/segmentation/test_data.py similarity index 100% rename from tests/vision/segmentation/test_data.py rename to tests/image/segmentation/test_data.py diff --git a/tests/vision/segmentation/test_model.py b/tests/image/segmentation/test_model.py similarity index 100% rename from tests/vision/segmentation/test_model.py rename to tests/image/segmentation/test_model.py diff --git a/tests/vision/segmentation/test_serialization.py b/tests/image/segmentation/test_serialization.py similarity index 100% rename from tests/vision/segmentation/test_serialization.py rename to 
tests/image/segmentation/test_serialization.py diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py new file mode 100644 index 0000000000..51ff23f563 --- /dev/null +++ b/tests/image/style_transfer/test_model.py @@ -0,0 +1,22 @@ +import pytest + +from flash.core.utilities.imports import _IMAGE_STLYE_TRANSFER +from flash.image.style_transfer import StyleTransfer + + +@pytest.mark.skipif(not _IMAGE_STLYE_TRANSFER, reason="image style transfer libraries aren't installed.") +def test_style_transfer_task(): + + model = StyleTransfer( + backbone="vgg11", content_layer="relu1_2", content_weight=10, style_layers="relu1_2", style_weight=11 + ) + assert model.perceptual_loss.content_loss.encoder.layer == "relu1_2" + assert model.perceptual_loss.content_loss.score_weight == 10 + assert "relu1_2" in [n for n, m in model.perceptual_loss.style_loss.named_modules()] + assert model.perceptual_loss.style_loss.score_weight == 11 + + +@pytest.mark.skipif(_IMAGE_STLYE_TRANSFER, reason="image style transfer libraries are installed.") +def test_style_transfer_task_import(): + with pytest.raises(ModuleNotFoundError, match="[image_style_transfer]"): + StyleTransfer() diff --git a/tests/vision/test_backbones.py b/tests/image/test_backbones.py similarity index 100% rename from tests/vision/test_backbones.py rename to tests/image/test_backbones.py From 278a87400f16d9adbef9ecc6352d147f42769345 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 12:25:38 +0100 Subject: [PATCH 38/60] update --- flash/image/style_transfer/data.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 60cabe7c0b..dc1f975cba 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -2,16 +2,18 @@ import pathlib from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union -import torchvision from torch import nn -from torchvision import transforms from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources from flash.core.data.process import Preprocess from flash.core.data.transforms import ApplyToKeys +from flash.core.utilities.imports import _TORCHVISION_AVAILABLE from flash.image.classification import ImageClassificationData, ImageClassificationPreprocess from flash.image.style_transfer.utils import raise_not_supported +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as T + __all__ = ["StyleTransferPreprocess", "StyleTransferData"] @@ -58,16 +60,16 @@ def __init__( def default_transforms(self) -> Optional[Dict[str, Callable]]: if self.training: return dict( - to_tensor_transform=torchvision.transforms.ToTensor(), + to_tensor_transform=T.ToTensor(), per_sample_transform_on_device=nn.Sequential( - transforms.Resize(self.image_size), - transforms.CenterCrop(self.image_size), + T.Resize(self.image_size), + T.CenterCrop(self.image_size), ), ) elif self.predicting: return dict( - pre_tensor_transform=transforms.Resize(self.image_size), - to_tensor_transform=torchvision.transforms.ToTensor(), + pre_tensor_transform=T.Resize(self.image_size), + to_tensor_transform=T.ToTensor(), ) else: # Style transfer doesn't support a validation or test phase, so we return nothing here From 740bb2206fe3f9a8063376cdec4f580f48b5d554 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 12:26:35 +0100 Subject: [PATCH 39/60] update --- requirements/datatype_image_style_transfer.txt | 1 + 1 file changed, 1 insertion(+) diff 
--git a/requirements/datatype_image_style_transfer.txt b/requirements/datatype_image_style_transfer.txt index dd830e0f0c..431e475fe2 100644 --- a/requirements/datatype_image_style_transfer.txt +++ b/requirements/datatype_image_style_transfer.txt @@ -1,3 +1,4 @@ torchvision +Pillow>=7.2 kornia>=0.5.1 pystiche From 3e2ad579fad589588d4c0a2790369061c5fc328f Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 12:29:37 +0100 Subject: [PATCH 40/60] update --- flash/image/style_transfer/model.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index c74fc602cd..534184b7c7 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -1,6 +1,7 @@ from typing import Any, cast, Dict, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union import torch +from pystiche.enc.encoder import Encoder from torch import nn from torch.optim.lr_scheduler import _LRScheduler @@ -15,6 +16,17 @@ import pystiche.demo from pystiche import enc, loss, ops from pystiche.image import read_image +else: + + class enc: + Encoder = None + MultiLayerEncoder = None + + class ops: + EncodingComparisonOperator = None + FeatureReconstructionOperator = None + MultiLayerEncodingOperator = None + from flash.image.style_transfer.utils import raise_not_supported From 753dfd7e714cf5e3b304cd5aadf89eda6f2b1103 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 12:39:45 +0100 Subject: [PATCH 41/60] update --- docs/source/reference/style_transfer.rst | 39 ++++++++++++++++++++++-- flash/image/style_transfer/model.py | 1 - 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst index 0d60c1b3b1..af67e9a08b 100644 --- a/docs/source/reference/style_transfer.rst +++ b/docs/source/reference/style_transfer.rst @@ -6,7 +6,8 @@ Style Transfer The task ******** -TODO +The Neural Style Transfer Task is an optimization method which extract the style from an image and apply it the content image (target). +The goal is that the output image looks like the content image, but “painted” in the style of the style reference image. ------ @@ -14,7 +15,41 @@ TODO Fit *** -TODO +First, you would have to import the :class:`~flash.image.style_transfer.StyleTransfer` +and :class:`~flash.image.style_transfer.StyleTransferData` from Flash. + +.. testcode:: style_transfer + + import sys + import flash + from flash.core.data.utils import download_data + from flash.image.style_transfer import StyleTransfer, StyleTransferData + + +Then, download some content images and create a DataModule. + +.. testcode:: style_transfer + + download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") + + data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4) + +Select a style image and pass it to the `StyleTransfer` task. + + +.. testcode:: style_transfer + + style_image = pystiche.demo.images()["paint"].read(size=256) + + model = StyleTransfer(style_image) + +Finally, create a Flash Trainer and pass it the model and datamodule. + +.. 
testcode:: style_transfer + trainer = flash.Trainer(max_epochs=2) + trainer.fit(model, data_module) + ------ diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 534184b7c7..895279ae17 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -1,7 +1,6 @@ from typing import Any, cast, Dict, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union import torch -from pystiche.enc.encoder import Encoder from torch import nn from torch.optim.lr_scheduler import _LRScheduler From a49474ac9f8ec52f7d662ee1145704dc93f94269 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 12:42:51 +0100 Subject: [PATCH 42/60] update --- docs/source/reference/style_transfer.rst | 3 ++- flash/image/style_transfer/model.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst index af67e9a08b..a16e387066 100644 --- a/docs/source/reference/style_transfer.rst +++ b/docs/source/reference/style_transfer.rst @@ -6,7 +6,7 @@ Style Transfer The task ******** -The Neural Style Transfer Task is an optimization method which extract the style from an image and apply it the content image (target). +The Neural Style Transfer Task is an optimization method which extracts the style from an image and applies it to another image while preserving its content. The goal is that the output image looks like the content image, but “painted” in the style of the style reference image. @@ -24,6 +24,7 @@ and :class:`~flash.image.style_transfer.StyleTransferData` from Flash. import flash from flash.core.data.utils import download_data from flash.image.style_transfer import StyleTransfer, StyleTransferData + import pystiche Then, download some content images and create a DataModule. diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 895279ae17..0fa1e23d9e 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -26,6 +26,9 @@ class ops: FeatureReconstructionOperator = None MultiLayerEncodingOperator = None + class loss: + PerceptualLoss = object +
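The loss placeholder added here completes a pattern used throughout model.py: when pystiche is missing, stand-in classes keep module-level references such as enc.MultiLayerEncoder and loss.PerceptualLoss resolvable at import time. Isolated into a simplified, self-contained sketch (the real code keys off _PYSTICHE_AVAILABLE rather than a try/except):

    try:
        from pystiche import enc, loss, ops
    except ImportError:
        class enc:
            Encoder = None
            MultiLayerEncoder = None

        class ops:
            EncodingComparisonOperator = None
            FeatureReconstructionOperator = None
            MultiLayerEncodingOperator = None

        class loss:
            PerceptualLoss = object

Patch 44 below swaps `PerceptualLoss = object` for a nested placeholder class, presumably so the attribute is a distinct stand-in type rather than `object` itself.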
From 55888c3a4d604cc5a215825add1e9409ce6bb853 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 12:49:31 +0100 Subject: [PATCH 43/60] update --- docs/source/reference/style_transfer.rst | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst index a16e387066..5e5b147421 100644 --- a/docs/source/reference/style_transfer.rst +++ b/docs/source/reference/style_transfer.rst @@ -9,6 +9,9 @@ The task The Neural Style Transfer Task is an optimization method which extracts the style from an image and applies it to another image while preserving its content. The goal is that the output image looks like the content image, but “painted” in the style of the style reference image. +.. image:: https://raw.githubusercontent.com/pystiche/pystiche/master/docs/source/graphics/banner/banner.jpg + :alt: style_transfer_example + ------ @@ -27,7 +30,7 @@ and :class:`~flash.image.style_transfer.StyleTransferData` from Flash. import pystiche -Then, download some content images and create a DataModule. +Then, download some content images and create a :class:`~flash.image.style_transfer.StyleTransferData` DataModule. .. testcode:: style_transfer @@ -35,8 +38,8 @@ Then, download some content images and create a :class:`~flash.image.style_transfer.StyleTransferData` DataModule. data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4) -Select a style image and pass it to the `StyleTransfer` task. +Select a style image and pass it to the `StyleTransfer` task. .. testcode:: style_transfer @@ -44,7 +47,7 @@ Select a style image and pass it to the `StyleTransfer` task. model = StyleTransfer(style_image) -Finally, create a Flash Trainer and pass it the model and datamodule. +Finally, create a Flash :class:`flash.core.trainer.Trainer` and pass it the model and datamodule. .. testcode:: style_transfer
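Assembled into a single script, the walkthrough these documentation patches build up reads as follows (the same calls as the testcode blocks above, so it should behave identically):

    import flash
    import pystiche
    from flash.core.data.utils import download_data
    from flash.image.style_transfer import StyleTransfer, StyleTransferData

    # grab a small set of content images and build the DataModule
    download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
    data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4)

    # pick a style image and hand it to the task
    style_image = pystiche.demo.images()["paint"].read(size=256)
    model = StyleTransfer(style_image)

    trainer = flash.Trainer(max_epochs=2)
    trainer.fit(model, data_module)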
f'{prefix}_{function_name}' diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py index ea761b60d7..2f2a7c8bbd 100644 --- a/flash/core/data/data_source.py +++ b/flash/core/data/data_source.py @@ -296,14 +296,10 @@ def generate_dataset( mock_dataset = typing.cast(AutoDataset, MockDataset()) with CurrentRunningStageFuncContext(running_stage, "load_data", self): - load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr( - self, DataPipeline._resolve_function_hierarchy( - "load_data", - self, - running_stage, - DataSource, - ) + resolved_func_name = DataPipeline._resolve_function_hierarchy( + "load_data", self, running_stage, DataSource ) + load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(self, resolved_func_name) parameters = signature(load_data).parameters if len(parameters) > 1 and "dataset" in parameters: # TODO: This was DATASET_KEY before data = load_data(data, mock_dataset) diff --git a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py index 7caa4e66ab..d67c06ad20 100644 --- a/tests/data/test_auto_dataset.py +++ b/tests/data/test_auto_dataset.py @@ -177,6 +177,9 @@ def test_preprocessing_data_source_with_running_stage(with_dataset): dataset = data_source.generate_dataset(range(10), running_stage=running_stage) + import pdb + pdb.set_trace() + assert len(dataset) == 10 for idx in range(len(dataset)): diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py index 2e195c1b9f..428a053b75 100644 --- a/tests/image/detection/test_data_model_integration.py +++ b/tests/image/detection/test_data_model_integration.py @@ -26,7 +26,7 @@ Image = None if _COCO_AVAILABLE: - from tests.vision.detection.test_data import _create_synth_coco_dataset + from tests.image.detection.test_data import _create_synth_coco_dataset @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="pycocotools is not installed for testing") From 4b1e7b9100e253c9a3f15f8a66f0cffc3c41cd22 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 13:33:49 +0100 Subject: [PATCH 46/60] update --- tests/data/test_auto_dataset.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py index d67c06ad20..7caa4e66ab 100644 --- a/tests/data/test_auto_dataset.py +++ b/tests/data/test_auto_dataset.py @@ -177,9 +177,6 @@ def test_preprocessing_data_source_with_running_stage(with_dataset): dataset = data_source.generate_dataset(range(10), running_stage=running_stage) - import pdb - pdb.set_trace() - assert len(dataset) == 10 for idx in range(len(dataset)): From 5c3e72a07c9d36c140fadab74f0c5376f21a4e02 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 13:43:23 +0100 Subject: [PATCH 47/60] update --- tests/image/style_transfer/test_model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py index 51ff23f563..ca00d3290f 100644 --- a/tests/image/style_transfer/test_model.py +++ b/tests/image/style_transfer/test_model.py @@ -1,10 +1,13 @@ import pytest -from flash.core.utilities.imports import _IMAGE_STLYE_TRANSFER +from flash.core.utilities.imports import _IMAGE_STLYE_TRANSFER, _PYSTICHE_GREATER_EQUAL_0_7_2 from flash.image.style_transfer import StyleTransfer -@pytest.mark.skipif(not _IMAGE_STLYE_TRANSFER, reason="image style transfer libraries aren't installed.") +@pytest.mark.skipif( + not (_IMAGE_STLYE_TRANSFER and _PYSTICHE_GREATER_EQUAL_0_7_2), + 
reason="image style transfer libraries aren't installed." +) def test_style_transfer_task(): model = StyleTransfer( From b2b313208fc33e2446c9dcbb038996d607398c47 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 16:11:48 +0100 Subject: [PATCH 48/60] update --- flash/image/style_transfer/backbone.py | 14 ++++++---- flash/image/style_transfer/data.py | 27 ++++++++++++++----- .../datatype_image_style_transfer.txt | 2 +- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/flash/image/style_transfer/backbone.py b/flash/image/style_transfer/backbone.py index 84dc723bdb..021c4a3ec7 100644 --- a/flash/image/style_transfer/backbone.py +++ b/flash/image/style_transfer/backbone.py @@ -1,3 +1,5 @@ +import re + from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _PYSTICHE_AVAILABLE @@ -9,16 +11,18 @@ from pystiche import enc - for mle_fn in dir(enc): + MLE_FN_PATTERN = re.compile(r"^(?P\w+?)_multi_layer_encoder$") - if "multi_layer_encoder" not in mle_fn: - continue + STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones") - name = mle_fn.split("_")[0] + for mle_fn in dir(enc): + match = MLE_FN_PATTERN.match(mle_fn) + if not match: + continue STYLE_TRANSFER_BACKBONES( fn=lambda: (getattr(enc, mle_fn)(), None), - name=mle_fn.split("_")[0], + name=match.group("name"), namespace="image/style_transfer", package="pystiche", ) diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index dc1f975cba..ce05069b33 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -8,7 +8,8 @@ from flash.core.data.process import Preprocess from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _TORCHVISION_AVAILABLE -from flash.image.classification import ImageClassificationData, ImageClassificationPreprocess +from flash.image.classification import ImageClassificationData +from flash.image.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource from flash.image.style_transfer.utils import raise_not_supported if _TORCHVISION_AVAILABLE: @@ -31,7 +32,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: return wrapper -class StyleTransferPreprocess(ImageClassificationPreprocess): +class StyleTransferPreprocess(Preprocess): def __init__( self, @@ -48,14 +49,29 @@ def __init__( if isinstance(image_size, int): image_size = (image_size, image_size) + super().__init__( train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - image_size=image_size, + data_sources={ + DefaultDataSources.FILES: ImagePathsDataSource(), + DefaultDataSources.FOLDERS: ImagePathsDataSource(), + DefaultDataSources.NUMPY: ImageNumpyDataSource(), + DefaultDataSources.TENSORS: ImageTensorDataSource(), + DefaultDataSources.TENSORS: ImageTensorDataSource(), + }, + default_data_source=DefaultDataSources.FILES, ) + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms, "image_size": self.image_size} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + @functools.partial(_apply_to_input, keys=DefaultDataKeys.INPUT) def default_transforms(self) -> Optional[Dict[str, Callable]]: if self.training: @@ -71,9 +87,8 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]: pre_tensor_transform=T.Resize(self.image_size), to_tensor_transform=T.ToTensor(), ) - else: - # Style transfer doesn't support a 
validation or test phase, so we return nothing here - return None + # Style transfer doesn't support a validation or test phase, so we return nothing here + return None class StyleTransferData(ImageClassificationData): diff --git a/requirements/datatype_image_style_transfer.txt b/requirements/datatype_image_style_transfer.txt index 431e475fe2..f0d78a115f 100644 --- a/requirements/datatype_image_style_transfer.txt +++ b/requirements/datatype_image_style_transfer.txt @@ -1,4 +1,4 @@ torchvision Pillow>=7.2 kornia>=0.5.1 -pystiche +pystiche>=0.7.2 From fa7c30403e9bf2dbe65c1a28320c3e5d6cc84da8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 16:18:39 +0100 Subject: [PATCH 49/60] change skipif --- tests/image/style_transfer/test_model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py index ca00d3290f..fbcdd6c7ad 100644 --- a/tests/image/style_transfer/test_model.py +++ b/tests/image/style_transfer/test_model.py @@ -4,10 +4,7 @@ from flash.image.style_transfer import StyleTransfer -@pytest.mark.skipif( - not (_IMAGE_STLYE_TRANSFER and _PYSTICHE_GREATER_EQUAL_0_7_2), - reason="image style transfer libraries aren't installed." -) +@pytest.mark.skipif(not _PYSTICHE_GREATER_EQUAL_0_7_2, reason="image style transfer libraries aren't installed.") def test_style_transfer_task(): model = StyleTransfer( From ed1574f61e06ffc748d416623564b7bf3cd1a39f Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 16:24:01 +0100 Subject: [PATCH 50/60] update --- tests/examples/test_scripts.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 24d93621d9..901b874c85 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -22,7 +22,6 @@ from flash.core.utilities.imports import ( _IMAGE_AVAILABLE, - _IMAGE_STLYE_TRANSFER, _PYSTICHE_GREATER_EQUAL_0_7_2, _TABULAR_AVAILABLE, _TEXT_AVAILABLE, @@ -31,7 +30,6 @@ ) _IMAGE_AVAILABLE = _IMAGE_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_9 -_IMAGE_STLYE_TRANSFER = _IMAGE_STLYE_TRANSFER and _PYSTICHE_GREATER_EQUAL_0_7_2 root = Path(__file__).parent.parent.parent @@ -100,7 +98,7 @@ def run_test(filepath): pytest.param( "finetuning", "style_transfer.py", - marks=pytest.mark.skipif(not _IMAGE_STLYE_TRANSFER, reason="pystiche is not installed") + marks=pytest.mark.skipif(not _PYSTICHE_GREATER_EQUAL_0_7_2, reason="pystiche is not installed") ), pytest.param( "predict", From f1737886d641d73f6113461288f7607c8bbd7167 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 16:26:11 +0100 Subject: [PATCH 51/60] update --- requirements/datatype_image_style_transfer.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/datatype_image_style_transfer.txt b/requirements/datatype_image_style_transfer.txt index f0d78a115f..4b0fd95c2e 100644 --- a/requirements/datatype_image_style_transfer.txt +++ b/requirements/datatype_image_style_transfer.txt @@ -1,4 +1,4 @@ torchvision Pillow>=7.2 kornia>=0.5.1 -pystiche>=0.7.2 +pystiche>=0.7.1 From 472cb922459a849519d055be4ca93db677ec29ca Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 16:27:33 +0100 Subject: [PATCH 52/60] update --- docs/source/reference/style_transfer.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst index 5e5b147421..a8d78f0dd6 100644 --- 
a/docs/source/reference/style_transfer.rst +++ b/docs/source/reference/style_transfer.rst @@ -21,7 +21,7 @@ Fit First, you would have to import the :class:`~flash.image.style_transfer.StyleTransfer` and :class:`~flash.image.style_transfer.StyleTransferData` from Flash. -.. testcode:: style_transfer +.. codeblock:: style_transfer import sys import flash @@ -32,7 +32,7 @@ and :class:`~flash.image.style_transfer.StyleTransferData` from Flash. Then, download some content images and create a :class:`~flash.image.style_transfer.StyleTransferData` DataModule. -.. testcode:: style_transfer +.. codeblock:: style_transfer download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") @@ -41,7 +41,7 @@ Then, download some content images and create a :class:`~flash.image.style_trans Select a style image and pass it to the `StyleTransfer` task. -.. testcode:: style_transfer +.. codeblock:: style_transfer style_image = pystiche.demo.images()["paint"].read(size=256) @@ -49,7 +49,7 @@ Select a style image and pass it to the `StyleTransfer` task. Finally, create a Flash :class:`flash.core.trainer.Trainer` and pass it the model and datamodule. -.. testcode:: style_transfer +.. codeblock:: style_transfer trainer = flash.Trainer(max_epochs=2) trainer.fit(model, data_module) From d2ab92844106ee2862032b5c6a9a981db6ac55dc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 17 May 2021 20:48:57 +0200 Subject: [PATCH 53/60] fix image size for preprocess --- flash/image/style_transfer/data.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index ce05069b33..822adf78a1 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -40,16 +40,14 @@ def __init__( val_transform: Optional[Union[Dict[str, Callable]]] = None, test_transform: Optional[Union[Dict[str, Callable]]] = None, predict_transform: Optional[Union[Dict[str, Callable]]] = None, - image_size: Union[int, Tuple[int, int]] = 256, + image_size: int = 256, ): if val_transform: raise_not_supported("validation") if test_transform: raise_not_supported("test") - if isinstance(image_size, int): - image_size = (image_size, image_size) - + self.image_size = image_size super().__init__( train_transform=train_transform, val_transform=val_transform, From 1a7819c6c1768f27db6ec7df2ae6375050da7f79 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 17 May 2021 21:18:52 +0200 Subject: [PATCH 54/60] fix style transfer requirements --- requirements/datatype_image_style_transfer.txt | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/requirements/datatype_image_style_transfer.txt b/requirements/datatype_image_style_transfer.txt index 4b0fd95c2e..e536cae01f 100644 --- a/requirements/datatype_image_style_transfer.txt +++ b/requirements/datatype_image_style_transfer.txt @@ -1,4 +1 @@ -torchvision -Pillow>=7.2 -kornia>=0.5.1 -pystiche>=0.7.1 +pystiche>=0.7.2 From a4c7f0a85a3cf88f5fa419bf0aec9257782ad36a Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 20:23:35 +0100 Subject: [PATCH 55/60] update --- flash/image/style_transfer/data.py | 2 ++ flash_examples/predict/style_transfer.py | 2 +- requirements/datatype_image_style_transfer.txt | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index ce05069b33..c0ef7b36bb 100644 --- a/flash/image/style_transfer/data.py +++ 
b/flash/image/style_transfer/data.py @@ -50,6 +50,8 @@ def __init__( if isinstance(image_size, int): image_size = (image_size, image_size) + self.image_size = image_size + super().__init__( train_transform=train_transform, val_transform=val_transform, diff --git a/flash_examples/predict/style_transfer.py b/flash_examples/predict/style_transfer.py index afd9741a29..b6a2cebda5 100644 --- a/flash_examples/predict/style_transfer.py +++ b/flash_examples/predict/style_transfer.py @@ -13,7 +13,7 @@ model = StyleTransfer.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/style_transfer_model.pt") -datamodule = StyleTransferData.from_folders(predict_folder="data/coco128/images", batch_size=4) +datamodule = StyleTransferData.from_folders(predict_folder="data/coco128/images/train2017", batch_size=4) trainer = flash.Trainer(max_epochs=2) trainer.predict(model, datamodule=datamodule) diff --git a/requirements/datatype_image_style_transfer.txt b/requirements/datatype_image_style_transfer.txt index 4b0fd95c2e..e90ea0def4 100644 --- a/requirements/datatype_image_style_transfer.txt +++ b/requirements/datatype_image_style_transfer.txt @@ -1,4 +1,4 @@ torchvision Pillow>=7.2 kornia>=0.5.1 -pystiche>=0.7.1 +pystiche>=0.7.2.post0 From 559446b52391f662d5cca11a6cdced425d03a2a6 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 20:25:04 +0100 Subject: [PATCH 56/60] update doc --- docs/source/reference/style_transfer.rst | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst index a8d78f0dd6..fbc2ef722b 100644 --- a/docs/source/reference/style_transfer.rst +++ b/docs/source/reference/style_transfer.rst @@ -21,7 +21,7 @@ Fit First, you would have to import the :class:`~flash.image.style_transfer.StyleTransfer` and :class:`~flash.image.style_transfer.StyleTransferData` from Flash. -.. codeblock:: style_transfer +.. testcode:: style_transfer import sys import flash @@ -32,7 +32,7 @@ and :class:`~flash.image.style_transfer.StyleTransferData` from Flash. Then, download some content images and create a :class:`~flash.image.style_transfer.StyleTransferData` DataModule. -.. codeblock:: style_transfer +.. testcode:: style_transfer download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") @@ -41,7 +41,7 @@ Then, download some content images and create a :class:`~flash.image.style_trans Select a style image and pass it to the `StyleTransfer` task. -.. codeblock:: style_transfer +.. testcode:: style_transfer style_image = pystiche.demo.images()["paint"].read(size=256) @@ -49,11 +49,16 @@ Select a style image and pass it to the `StyleTransfer` task. Finally, create a Flash :class:`flash.core.trainer.Trainer` and pass it the model and datamodule. -.. codeblock:: style_transfer +.. testcode:: style_transfer trainer = flash.Trainer(max_epochs=2) trainer.fit(model, data_module) +.. testoutput:: + :hide: + + ... 
+
 ------


From a3b95d535ba0f9bc2f1149d6bfd24c4b2bf92b80 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 17 May 2021 21:27:06 +0200
Subject: [PATCH 57/60] fix style transfer requirements

---
 requirements/datatype_image_style_transfer.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/datatype_image_style_transfer.txt b/requirements/datatype_image_style_transfer.txt
index 6681e7f3e0..e536cae01f 100644
--- a/requirements/datatype_image_style_transfer.txt
+++ b/requirements/datatype_image_style_transfer.txt
@@ -1 +1 @@
-pystiche>=0.7.2.post0
+pystiche>=0.7.2

From b8d93beb685a2463cb9ee8427681d6d78a04d35a Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 17 May 2021 20:42:40 +0100
Subject: [PATCH 58/60] update

---
 flash_examples/predict/style_transfer.py | 36 ++++++++++++++++++++++--
 1 file changed, 33 insertions(+), 3 deletions(-)

diff --git a/flash_examples/predict/style_transfer.py b/flash_examples/predict/style_transfer.py
index b6a2cebda5..bd8eb6041f 100644
--- a/flash_examples/predict/style_transfer.py
+++ b/flash_examples/predict/style_transfer.py
@@ -1,19 +1,49 @@
 import sys

+import numpy as np
+import torch
+from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
+
 import flash
 from flash.core.data.utils import download_data
-from flash.core.utilities.imports import _PYSTICHE_AVAILABLE
+from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYSTICHE_AVAILABLE
 from flash.image.style_transfer import StyleTransfer, StyleTransferData

 if not _PYSTICHE_AVAILABLE:
     print("Please, run `pip install pystiche`")
     sys.exit(1)

+
+class StyleTransferWriter(BasePredictionWriter):
+
+    def __init__(self) -> None:
+        super().__init__("batch")
+
+    def write_on_batch_end(
+        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
+    ) -> None:
+        """
+        Implement the logic to save a given batch of predictions, for example:
+        torch.save({"preds": prediction, "batch_indices": batch_indices}, f"prediction_{batch_idx}.pt")
+        """
+
+
 download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

 model = StyleTransfer.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/style_transfer_model.pt")

 datamodule = StyleTransferData.from_folders(predict_folder="data/coco128/images/train2017", batch_size=4)

-trainer = flash.Trainer(max_epochs=2)
-trainer.predict(model, datamodule=datamodule)
+trainer = flash.Trainer(max_epochs=2, callbacks=StyleTransferWriter(), limit_predict_batches=1)
+predictions = trainer.predict(model, datamodule=datamodule)
+
+# Display the first stylized image.
+image_prediction = torch.stack(predictions[0])[0].numpy()
+
+if _MATPLOTLIB_AVAILABLE and not flash._IS_TESTING:
+    import matplotlib.pyplot as plt
+    image = np.moveaxis(image_prediction, 0, 2)
+    image -= image.min()
+    image /= image.max()
+    plt.imshow(image)
+    plt.show()

From 4fbf11c97e86c8a3a23dd7856d41cce5d0ae464a Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 17 May 2021 21:48:11 +0200
Subject: [PATCH 59/60] add reference to pystiche

---
 docs/source/reference/style_transfer.rst | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst
index fbc2ef722b..86b2169c04 100644
--- a/docs/source/reference/style_transfer.rst
+++ b/docs/source/reference/style_transfer.rst
@@ -12,6 +12,10 @@ The goal is that the output image looks like the content image, but “painted

 .. 
image:: https://raw.githubusercontent.com/pystiche/pystiche/master/docs/source/graphics/banner/banner.jpg
   :alt: style_transfer_example

+Lightning Flash :class:`~flash.image.style_transfer.StyleTransfer` and
+:class:`~flash.image.style_transfer.StyleTransferData` internally rely on `pystiche <https://github.com/pystiche/pystiche>`_ as
+the backend.
+
 ------

 ***
 Fit
 ***

From 31efb530d59e7b14d5b54284d70e7c3e2e20cbf3 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 17 May 2021 21:48:27 +0200
Subject: [PATCH 60/60] remove unnecessary import

---
 docs/source/reference/style_transfer.rst | 1 -
 1 file changed, 1 deletion(-)

diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst
index 86b2169c04..87495070cf 100644
--- a/docs/source/reference/style_transfer.rst
+++ b/docs/source/reference/style_transfer.rst
@@ -27,7 +27,6 @@ and :class:`~flash.image.style_transfer.StyleTransferData` from Flash.

 .. testcode:: style_transfer

-    import sys
     import flash
     from flash.core.data.utils import download_data
     from flash.image.style_transfer import StyleTransfer, StyleTransferData
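
PATCH 45 reorders the candidate prefixes in `_resolve_function_hierarchy` so that the un-prefixed hook name is tried last and a stage-specific override wins. A toy sketch of why the order matters; the class and resolver below are hypothetical stand-ins, not Flash's actual implementation, which checks for overrides against a base class rather than mere attribute presence:

    class MyPreprocess:
        def load_data(self, data):
            return data

        def train_load_data(self, data):
            # Stage-specific override that should win during TRAINING.
            return [d for d in data if d is not None]


    def resolve(obj, function_name, prefixes):
        # Return the first hook name in prefix order that exists on obj.
        for prefix in prefixes:
            name = function_name if prefix is None else f"{prefix}_{function_name}"
            if hasattr(obj, name):
                return name
        return function_name


    # Old order: the generic hook is found first and shadows the override.
    assert resolve(MyPreprocess(), "load_data", [None, "train", "fit"]) == "load_data"
    # New order from PATCH 45: the stage-specific hook is found first.
    assert resolve(MyPreprocess(), "load_data", ["train", "fit", None]) == "train_load_data"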
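
PATCH 48 swaps the substring check in backbone.py for an anchored regex with a named group, so only attributes that follow the `<name>_multi_layer_encoder` naming scheme are registered. A standalone check of that pattern; the encoder names below are illustrative:

    import re

    MLE_FN_PATTERN = re.compile(r"^(?P<name>\w+?)_multi_layer_encoder$")

    for fn_name in ("vgg16_multi_layer_encoder", "vgg19_multi_layer_encoder", "some_helper"):
        match = MLE_FN_PATTERN.match(fn_name)
        if match:
            # The named group becomes the registry entry name, e.g. "vgg16".
            print(match.group("name"))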
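
PATCH 58 introduces a `StyleTransferWriter` whose `write_on_batch_end` body consists only of a docstring, so nothing is written to disk yet. A sketch of how that hook could be filled in, assuming the filename scheme suggested by the docstring:

    import torch
    from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter


    class StyleTransferWriter(BasePredictionWriter):

        def __init__(self) -> None:
            # "batch" interval: the hook fires once per predict batch.
            super().__init__("batch")

        def write_on_batch_end(
            self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
        ) -> None:
            # Persist each batch of stylized images. The f-string ensures
            # batch_idx is interpolated into the filename.
            torch.save({"preds": prediction, "batch_indices": batch_indices}, f"prediction_{batch_idx}.pt")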
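
Finally, the documentation snippets edited across PATCHES 42, 43, 52, 56, and 60 add up to a single fit workflow. Assembled into one minimal script, using only the calls that appear in those snippets:

    import flash
    import pystiche
    from flash.core.data.utils import download_data
    from flash.image.style_transfer import StyleTransfer, StyleTransferData

    # Content images: the coco128 dataset used throughout the docs and examples.
    download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
    data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4)

    # Style image: shipped with pystiche's demo assets.
    style_image = pystiche.demo.images()["paint"].read(size=256)
    model = StyleTransfer(style_image)

    trainer = flash.Trainer(max_epochs=2)
    trainer.fit(model, data_module)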