This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[WIP] add style transfer task with pystiche #262

Merged (72 commits) on May 17, 2021

Conversation

pmeier (Contributor) commented May 5, 2021

What does this PR do?

Adds a style transfer task that uses pystiche as its backend.

Note: Change code-block to testcode when 0.7.2 is out.

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements) Integration with Lightning Flash: pystiche/pystiche#484
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? [not needed for typos/docs]
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

pep8speaks commented May 5, 2021

Hello @pmeier! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-05-17 19:48:40 UTC

pmeier (Contributor, Author) left a comment

This is only a very preliminary state. I've hit several roadblocks that need to be resolved:

  • Models for neural style transfer are trained in an unsupervised manner, so all I need is a dataset that supplies images without labels or annotations.
  • Following from the point above, there is no train / val / test split. The process is finished once training completes. If there is anything like validation or testing, it is done manually by trying a few examples; there is no objective way to put a number on the quality of the stylization.
  • The models used as transformers are not named. As far as I've seen, models are usually loaded by name, which is not possible here. We could fall back to an author / year combination of the paper in which the architecture was published.
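
The first two points could be sketched roughly as follows: a map-style dataset that yields images only, with no labels and no split. This is a minimal, standard-library-only sketch and not what Flash or pystiche ship. The class name `UnlabeledImageFolder`, the `loader` parameter, and the extension list are assumptions made for illustration. The default loader just returns raw bytes; a real pipeline would decode the image instead.

```python
from pathlib import Path
from typing import Any, Callable, List, Sequence, Union

# Hypothetical default list of extensions; extend as needed.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def _read_bytes(path: Path) -> bytes:
    """Placeholder loader; a real pipeline would decode the image instead."""
    return path.read_bytes()


class UnlabeledImageFolder:
    """Map-style dataset that returns images only (no targets, no splits)."""

    def __init__(
        self,
        root: Union[str, Path],
        loader: Callable[[Path], Any] = _read_bytes,
        extensions: Sequence[str] = IMAGE_EXTENSIONS,
    ) -> None:
        self.loader = loader
        allowed = tuple(ext.lower() for ext in extensions)
        # Sort the files so the order is deterministic across runs.
        self.files: List[Path] = sorted(
            p for p in Path(root).rglob("*")
            if p.is_file() and p.suffix.lower() in allowed
        )

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> Any:
        # Only the image is returned; there is no label to pair it with.
        return self.loader(self.files[index])
```

Because the class only implements `__len__` and `__getitem__`, it follows the map-style dataset protocol and could be wrapped by a `torch.utils.data.DataLoader` without further changes.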

I'll fix the linting errors and update the documentation, tests, and the changelog when the main part is resolved.
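
For the naming question in the last bullet, one option is to derive the registry key from the author and year of the paper. The sketch below uses a tiny stand-in registry rather than Flash's `FlashRegistry`. `ModelRegistry`, `TRANSFORMERS`, `build_doe_transformer`, and the "Doe, 2021" entry are all placeholders, not real architectures or APIs.

```python
from typing import Callable, Dict, List


class ModelRegistry:
    """Tiny name -> factory mapping (a stand-in, not Flash's FlashRegistry)."""

    def __init__(self) -> None:
        self._factories: Dict[str, Callable[..., object]] = {}

    def register(self, author: str, year: int) -> Callable:
        # Build the key from the paper's first author and publication year.
        key = f"{author.lower()}_{year}"

        def decorator(fn: Callable[..., object]) -> Callable[..., object]:
            if key in self._factories:
                raise KeyError(f"'{key}' is already registered")
            self._factories[key] = fn
            return fn

        return decorator

    def get(self, name: str) -> Callable[..., object]:
        return self._factories[name]

    def available(self) -> List[str]:
        return sorted(self._factories)


TRANSFORMERS = ModelRegistry()


@TRANSFORMERS.register(author="Doe", year=2021)  # placeholder paper
def build_doe_transformer(**kwargs):
    # Placeholder: a real factory would construct and return the network.
    return {"architecture": "doe_2021", **kwargs}
```

A model could then be requested with `TRANSFORMERS.get("doe_2021")()`, which mirrors loading by name even though the architecture itself has no official name.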

flash/vision/style_transfer/data.py (outdated, resolved)
flash/vision/style_transfer/model.py (outdated, resolved)
flash_examples/predict/style_transfer.py (outdated, resolved)
flash_examples/predict/style_transfer.py (outdated, resolved)

tchaton (Contributor) commented May 6, 2021

Hey @pmeier,

Awesome that you started. You need to properly create a task. Here is some pseudocode to get you started.

from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union

import numpy as np
import torch
import torchmetrics
from PIL import Image
from pystiche import enc, loss, ops
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler

# Assumption: these Flash import paths may differ between Flash versions.
from flash.core.data.process import Serializer
from flash.core.model import Task
from flash.core.registry import FlashRegistry

# Registry of transformer factories; assumed to be populated elsewhere.
STYLE_TRANSFER_MODELS = FlashRegistry("style_transfer")


class StyleTransfer(Task):

    models: FlashRegistry = STYLE_TRANSFER_MODELS

    def __init__(
        self,
        content_image: Optional[Union[Image.Image, str, np.ndarray]] = None,
        style_loss: Optional[Callable] = None,
        content_loss: Optional[Callable] = None,
        perceptual_loss: Optional[Callable] = None,
        model: Union[str, nn.Module] = "transformer",
        model_kwargs: Optional[Dict] = None,
        pretrained: bool = True,  # assumed flag, forwarded to the model factory
        optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
        scheduler_kwargs: Optional[Dict[str, Any]] = None,
        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
        learning_rate: float = 1e-3,
        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
    ):
        # Build the default perceptual loss unless a complete one is passed in.
        if perceptual_loss is None:
            content_loss = content_loss or self.default_content_loss()
            style_loss = style_loss or self.default_style_transfer()
            perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)

        if content_image is not None:
            # pystiche expects a tensor here; converting PIL / path / array inputs is still open.
            perceptual_loss.set_content_image(content_image)

        # Look up registered models by name; pass modules through unchanged.
        if isinstance(model, str):
            model = self.models.get(model)(pretrained=pretrained, **(model_kwargs or {}))

        # Module attributes can only be assigned after the parent constructor has run.
        super().__init__(
            model=model,
            loss_fn=perceptual_loss,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            scheduler=scheduler,
            scheduler_kwargs=scheduler_kwargs,
            metrics=metrics,
            learning_rate=learning_rate,
            serializer=serializer,
        )

        self.perceptual_loss = perceptual_loss
        self.save_hyperparameters()

    def default_content_loss(self):
        # Feature reconstruction loss on a single VGG16 layer.
        multi_layer_encoder = enc.vgg16_multi_layer_encoder()
        content_layer = "relu2_2"
        content_encoder = multi_layer_encoder.extract_encoder(content_layer)
        content_weight = 1e5
        return ops.FeatureReconstructionOperator(
            content_encoder, score_weight=content_weight
        )

    def default_style_transfer(self):
        class GramOperator(ops.GramOperator):
            def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
                # Additionally normalize the Gram matrix by the number of channels.
                repr = super().enc_to_repr(enc)
                num_channels = repr.size()[1]
                return repr / num_channels

        # Gram-based style loss summed over several VGG16 layers.
        multi_layer_encoder = enc.vgg16_multi_layer_encoder()
        style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
        style_weight = 1e10
        return ops.MultiLayerEncodingOperator(
            multi_layer_encoder,
            style_layers,
            lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
            layer_weights="sum",
            score_weight=style_weight,
        )

    def forward(self, x):
        # not sure about this part: stylize the input, then score the result
        output = self.model(x)
        return self.perceptual_loss(output)


# in finetuning.

content_image = ...
dm = StyleDataModule.from_folder(...)
model = StyleTransfer(content_image=content_image)
trainer = Trainer(...)
trainer.fit(model, dm)
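
The `GramOperator` subclass above divides the Gram representation by the number of channels. As a rough illustration of what that representation is, the pure-Python sketch below computes a Gram matrix for a tiny feature map and then applies the same extra division. The helper names are hypothetical, and dividing by the number of spatial positions is an assumption made for the example; pystiche's own normalization may differ.

```python
from typing import List

Matrix = List[List[float]]


def gram_matrix(features: Matrix) -> Matrix:
    """Gram matrix of a feature map given as C rows of N flattened positions."""
    num_positions = len(features[0])
    return [
        [
            # Inner product of two channels, averaged over spatial positions.
            sum(a * b for a, b in zip(row_i, row_j)) / num_positions
            for row_j in features
        ]
        for row_i in features
    ]


def channel_normalized(gram: Matrix) -> Matrix:
    """Extra division by the channel count, mirroring the subclass above."""
    num_channels = len(gram)
    return [[value / num_channels for value in row] for row in gram]


# Two channels with four spatial positions each.
features = [
    [1.0, 2.0, 3.0, 4.0],
    [0.0, 1.0, 0.0, 1.0],
]
gram = gram_matrix(features)       # [[7.5, 1.5], [1.5, 0.5]]
scaled = channel_normalized(gram)  # [[3.75, 0.75], [0.75, 0.25]]
```

The Gram matrix is symmetric and discards spatial layout, which is why it is used as a style representation while the content loss compares the features directly.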

flash/vision/__init__.py (outdated, resolved)
flash/vision/style_transfer/model.py (outdated, resolved)
flash_examples/predict/style_transfer.py (outdated, resolved)

codecov bot commented May 6, 2021

Codecov Report

Merging #262 (31efb53) into master (c28bafa) will decrease coverage by 0.49%.
The diff coverage is 75.31%.

@@            Coverage Diff             @@
##           master     #262      +/-   ##
==========================================
- Coverage   87.54%   87.05%   -0.50%     
==========================================
  Files          73       78       +5     
  Lines        3815     3970     +155     
==========================================
+ Hits         3340     3456     +116     
- Misses        475      514      +39     
Flag        Coverage Δ
unittests   87.05% <75.31%> (-0.50%) ⬇️
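
The overall percentages in the report can be reproduced from the hit and line counts. The sketch below is only a cross-check of the numbers shown here. The function name is hypothetical, and truncating to two decimals is an assumption inferred from the figures (3340 / 3815 is roughly 87.549%, reported as 87.54%), not something taken from Codecov's documentation.

```python
def coverage_percent(hits: int, lines: int) -> float:
    """Coverage in percent, truncated (not rounded) to two decimals."""
    # Integer arithmetic avoids floating-point surprises at the cut-off.
    return (hits * 10_000 // lines) / 100


before = coverage_percent(hits=3340, lines=3815)  # master
after = coverage_percent(hits=3456, lines=3970)   # with this PR

# Misses are simply the lines that were not hit.
misses_before = 3815 - 3340
misses_after = 3970 - 3456
```

With these inputs, `before` is 87.54 and `after` is 87.05, and the miss counts come out as 475 and 514, matching the table.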

Impacted Files                            Coverage Δ
flash/core/data/transforms.py             93.44% <25.00%> (-4.81%) ⬇️
flash/image/style_transfer/data.py        53.70% <53.70%> (ø)
flash/image/style_transfer/utils.py       75.00% <75.00%> (ø)
flash/image/style_transfer/model.py       85.71% <85.71%> (ø)
flash/core/data/data_pipeline.py          92.43% <100.00%> (+0.02%) ⬆️
flash/core/data/data_source.py            95.75% <100.00%> (+0.02%) ⬆️
flash/core/utilities/imports.py           89.13% <100.00%> (+0.75%) ⬆️
flash/image/__init__.py                   100.00% <100.00%> (ø)
flash/image/style_transfer/__init__.py    100.00% <100.00%> (ø)
flash/image/style_transfer/backbone.py    100.00% <100.00%> (ø)
... and 5 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update c28bafa...31efb53.

pmeier (Contributor, Author) left a comment

Apart from my comments / questions below, I'm wondering whether this example should be in predict or finetune.

flash/vision/style_transfer/model.py (outdated, resolved)
flash/vision/style_transfer/model.py (outdated, resolved)
flash/vision/style_transfer/model.py (outdated, resolved)
flash/vision/style_transfer/model.py (outdated, resolved)
flash_examples/predict/style_transfer.py (outdated, resolved)
flash_examples/predict/style_transfer.py (outdated, resolved)

pmeier requested review from edgarriba and tchaton on May 10, 2021, 07:32
tchaton marked this pull request as ready for review on May 17, 2021, 12:58
tchaton requested a review from edenlightning as a code owner on May 17, 2021, 12:58

pmeier (Contributor, Author) left a comment

Hey @tchaton, thanks for the commits. I have some comments below. Additionally, it looks like you have added a lot of changes that seemingly have nothing to do with this PR. Was that intentional?

flash/core/data/data_pipeline.py (resolved)
flash/image/style_transfer/backbone.py (outdated, resolved)
flash/image/style_transfer/backbone.py (outdated, resolved)
requirements/datatype_image_style_transfer.txt (outdated, resolved)
tests/examples/test_scripts.py (outdated, resolved)

mergify bot removed the "has conflicts" label on May 17, 2021

edgarriba (Contributor) left a comment

LGTM

docs/source/reference/style_transfer.rst (resolved)
flash/image/style_transfer/model.py (resolved)
flash/image/style_transfer/model.py (resolved)

ethanwharris (Collaborator) left a comment

Awesome! Small comment

docs/source/reference/style_transfer.rst (resolved)

tchaton merged commit 7c89fc1 into Lightning-Universe:master on May 17, 2021
pmeier deleted the style-transfer branch on May 17, 2021, 20:15
6 participants