Skip to content

Conversation

@marvingabler
Copy link
Contributor

What does this PR do?

This PR adds the feature of accepting arbitary number of input and output channels when using the Swin2SR model. This allows to perform super resolution from greyscale (1 channel) to color (rgb), or from low resolution multi band satellite to high resolution rgb satellite.

All examples and pretrained models are running as expected based on my tests. No new dependencies have been added.

Just use it like

from transformers import Swin2SRForImageSuperResolution, Swin2SRConfig
import torch

Swin2SRConfig = (
     num_channels_in=1,
     num_channels_out=3
)
model = Swin2SRForImageSuperResolution(Swin2SRConfig)

with torch.no_grad():
    # or use the image preprocessor per default
    out = model({"pixel_values":torch.randn((1,1,264,264))})

Fixes #26566.

Before submitting

Tagging the reviewers

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @marvingabler
Thanks for your contribution! as you are removing the num_channels attribute I think that this is a breaking change, what about keeping num_channels and make it behave as num_channels_in and use num_channels_out as an optional argument that is initialized as the same value as num_channels in case it is set to None. That way I believe changes will be backward compatible. What do you think?

@marvingabler
Copy link
Contributor Author

Good point, yes lets do that! Let me update the PR soon :)

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking great to me! thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Can you add a small test as well? Making sure that a dummy model with a this can still perform as expected! 😉

@marvingabler
Copy link
Contributor Author

Just realized that there are a couple of more changes required, as the Swin2SRForImageSuperResolution denormalizes based on the input images, while for the case of mapping from multiband images to single band, the mean&stds of inputs and outputs differ. Will add the changes soon.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me thanks for adding this.

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks!

@LysandreJik LysandreJik merged commit 0a3b9d0 into huggingface:main Oct 5, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

SWIN2SR: Allow to choose number of in_channels and out_channels

6 participants