2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -2078,6 +2078,7 @@
_import_structure["models.swin"].extend(
[
"SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
"SwinBackbone",
"SwinForImageClassification",
"SwinForMaskedImageModeling",
"SwinModel",
@@ -5041,6 +5042,7 @@
)
from .models.swin import (
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
SwinBackbone,
SwinForImageClassification,
SwinForMaskedImageModeling,
SwinModel,
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -869,6 +869,7 @@
("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"),
("resnet", "ResNetBackbone"),
("swin", "SwinBackbone"),
]
)

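A minimal sketch of what this one-line registration enables, assuming network access to download weights: `AutoBackbone` can now dispatch a Swin checkpoint to the new class. The checkpoint name is the one used in the `SwinBackbone` docstring example later in this PR.

```python
from transformers import AutoBackbone

# The ("swin", "SwinBackbone") entry in the backbone mapping above routes
# this checkpoint to the new backbone class.
model = AutoBackbone.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
print(type(model).__name__)  # SwinBackbone
```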
28 changes: 20 additions & 8 deletions src/transformers/models/donut/modeling_donut_swin.py
@@ -523,7 +523,6 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.set_shift_and_window_size(input_resolution)
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -585,7 +584,9 @@ def forward(
shortcut = hidden_states

hidden_states = self.layernorm_before(hidden_states)

hidden_states = hidden_states.view(batch_size, height, width, channels)

# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

@@ -677,14 +678,15 @@ def forward(

hidden_states = layer_outputs[0]

hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(layer_outputs[0], input_dimensions)
hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
else:
output_dimensions = (height, width, height, width)

stage_outputs = (hidden_states, output_dimensions)
stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

if output_attentions:
stage_outputs += layer_outputs[1:]
@@ -722,9 +724,9 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, DonutSwinEncoderOutput]:
all_input_dimensions = ()
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
@@ -755,12 +757,22 @@ def custom_forward(*inputs):
layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]
output_dimensions = layer_outputs[1]
hidden_states_before_downsampling = layer_outputs[1]
output_dimensions = layer_outputs[2]

input_dimensions = (output_dimensions[-2], output_dimensions[-1])
all_input_dimensions += (input_dimensions,)

if output_hidden_states:
if output_hidden_states and output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states_before_downsampling.shape
# rearrange b (h w) c -> b c h w
# here we use the original (not downsampled) height and width
reshaped_hidden_state = hidden_states_before_downsampling.view(
batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states.shape
# rearrange b (h w) c -> b c h w
reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
@@ -769,7 +781,7 @@ def custom_forward(*inputs):
all_reshaped_hidden_states += (reshaped_hidden_state,)

if output_attentions:
all_self_attentions += layer_outputs[2:]
all_self_attentions += layer_outputs[3:]

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
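The encoder change above is the core of this PR: each stage now also returns its hidden states before patch merging, reshaped with the original (not downsampled) height and width, so that a backbone can expose per-stage features at the expected spatial resolution. A minimal sketch of that `b (h w) c -> b c h w` reshape, with illustrative dimensions:

```python
import torch

# Illustrative sizes: output_dimensions[0] and output_dimensions[1] carry the
# pre-downsampling height and width of the stage.
batch_size, height, width, hidden_size = 2, 8, 8, 96
hidden_states_before_downsampling = torch.randn(batch_size, height * width, hidden_size)

# rearrange b (h w) c -> b c h w, as in the encoder loop above
reshaped = hidden_states_before_downsampling.view(batch_size, height, width, hidden_size)
reshaped = reshaped.permute(0, 3, 1, 2)
print(reshaped.shape)  # torch.Size([2, 96, 8, 8])
```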
2 changes: 2 additions & 0 deletions src/transformers/models/swin/__init__.py
@@ -36,6 +36,7 @@
"SwinForMaskedImageModeling",
"SwinModel",
"SwinPreTrainedModel",
"SwinBackbone",
]

try:
@@ -63,6 +64,7 @@
else:
from .modeling_swin import (
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
SwinBackbone,
SwinForImageClassification,
SwinForMaskedImageModeling,
SwinModel,
14 changes: 14 additions & 0 deletions src/transformers/models/swin/configuration_swin.py
@@ -83,6 +83,9 @@ class SwinConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
encoder_stride (`int`, `optional`, defaults to 32):
Factor to increase the spatial resolution by in the decoder head for masked image modeling.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.

Example:

@@ -125,6 +128,7 @@ def __init__(
initializer_range=0.02,
layer_norm_eps=1e-5,
encoder_stride=32,
out_features=None,
**kwargs
):
super().__init__(**kwargs)
@@ -151,6 +155,16 @@ def __init__(
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
self.out_features = out_features


class SwinOnnxConfig(OnnxConfig):
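A minimal sketch of the new `out_features` option: valid names come from `config.stage_names`, which is `"stem"` plus one `stageN` entry per element of `depths` (the default depths give four stages), and an invalid name triggers the `ValueError` added above.

```python
from transformers import SwinConfig

config = SwinConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
print(config.stage_names)  # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']

try:
    SwinConfig(out_features=["stage5"])
except ValueError as err:
    print(err)  # Feature stage5 is not a valid feature name. Valid names are ...
```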
146 changes: 137 additions & 9 deletions src/transformers/models/swin/modeling_swin.py
@@ -26,7 +26,8 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
ModelOutput,
@@ -589,7 +590,6 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.set_shift_and_window_size(input_resolution)
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -651,7 +651,9 @@ def forward(
shortcut = hidden_states

hidden_states = self.layernorm_before(hidden_states)

hidden_states = hidden_states.view(batch_size, height, width, channels)

# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

@@ -742,14 +744,15 @@ def forward(

hidden_states = layer_outputs[0]

hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(layer_outputs[0], input_dimensions)
hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
else:
output_dimensions = (height, width, height, width)

stage_outputs = (hidden_states, output_dimensions)
stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

if output_attentions:
stage_outputs += layer_outputs[1:]
@@ -786,9 +789,9 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, SwinEncoderOutput]:
all_input_dimensions = ()
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
@@ -819,12 +822,22 @@ def custom_forward(*inputs):
layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]
output_dimensions = layer_outputs[1]
hidden_states_before_downsampling = layer_outputs[1]
output_dimensions = layer_outputs[2]

input_dimensions = (output_dimensions[-2], output_dimensions[-1])
all_input_dimensions += (input_dimensions,)

if output_hidden_states:
if output_hidden_states and output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states_before_downsampling.shape
# rearrange b (h w) c -> b c h w
# here we use the original (not downsampled) height and width
reshaped_hidden_state = hidden_states_before_downsampling.view(
batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
)
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
batch_size, _, hidden_size = hidden_states.shape
# rearrange b (h w) c -> b c h w
reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
@@ -833,7 +846,7 @@ def custom_forward(*inputs):
all_reshaped_hidden_states += (reshaped_hidden_state,)

if output_attentions:
all_self_attentions += layer_outputs[2:]
all_self_attentions += layer_outputs[3:]

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
@@ -1214,3 +1227,118 @@ def forward(
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)


@add_start_docstrings(
"""
Swin backbone, to be used with frameworks like DETR and MaskFormer.
""",
SWIN_START_DOCSTRING,
)
class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
def __init__(self, config: SwinConfig):
super().__init__(config)

self.stage_names = config.stage_names

self.embeddings = SwinEmbeddings(config)
self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]

num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.out_feature_channels = {}
self.out_feature_channels["stem"] = config.embed_dim
for i, stage in enumerate(self.stage_names[1:]):
self.out_feature_channels[stage] = num_features[i]

# Add layer norms to hidden states of out_features
hidden_states_norms = dict()
for stage, num_channels in zip(self.out_features, self.channels):
hidden_states_norms[stage] = nn.LayerNorm(num_channels)
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.patch_embeddings

@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]

def forward(
self,
pixel_values: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
"""
Returns:

Examples:

```python
>>> from transformers import AutoImageProcessor, AutoBackbone
>>> import torch
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = AutoBackbone.from_pretrained(
... "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
... )

>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
[1, 768, 7, 7]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

embedding_output, input_dimensions = self.embeddings(pixel_values)

outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=None,
output_attentions=output_attentions,
output_hidden_states=True,
output_hidden_states_before_downsampling=True,
return_dict=True,
)

hidden_states = outputs.reshaped_hidden_states

feature_maps = ()
for stage, hidden_state in zip(self.stage_names, hidden_states):
if stage in self.out_features:
batch_size, num_channels, height, width = hidden_state.shape
hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
hidden_state = hidden_state.view(batch_size, height * width, num_channels)
hidden_state = self.hidden_states_norms[stage](hidden_state)
hidden_state = hidden_state.view(batch_size, height, width, num_channels)
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
feature_maps += (hidden_state,)

if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs.hidden_states,)
return output

return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
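A minimal sketch tying the pieces together, using a randomly initialized model so no checkpoint download is needed: with the default swin-tiny geometry (`embed_dim=96`, four stages whose width doubles at each patch merge), `SwinBackbone` reports the per-stage channel counts computed in `__init__` above, and the feature maps come out in `(batch, channels, height, width)` order.

```python
import torch

from transformers import SwinBackbone, SwinConfig

config = SwinConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
backbone = SwinBackbone(config)
print(backbone.channels)  # [96, 192, 384, 768]

outputs = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in outputs.feature_maps])
# [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]
```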