diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index e6758c7df38f..13f70063423a 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -2078,6 +2078,7 @@
     _import_structure["models.swin"].extend(
         [
             "SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "SwinBackbone",
             "SwinForImageClassification",
             "SwinForMaskedImageModeling",
             "SwinModel",
@@ -5041,6 +5042,7 @@
         )
         from .models.swin import (
             SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+            SwinBackbone,
             SwinForImageClassification,
             SwinForMaskedImageModeling,
             SwinModel,
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index c8dcc9aed1e7..b3f751c688bf 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -869,6 +869,7 @@
         ("maskformer-swin", "MaskFormerSwinBackbone"),
         ("nat", "NatBackbone"),
         ("resnet", "ResNetBackbone"),
+        ("swin", "SwinBackbone"),
     ]
 )
 
diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index f1f682e9ea46..46b0c54c4cbf 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -523,7 +523,6 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
         self.shift_size = shift_size
         self.window_size = config.window_size
         self.input_resolution = input_resolution
-        self.set_shift_and_window_size(input_resolution)
         self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
         self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
         self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -585,7 +584,9 @@ def forward(
         shortcut = hidden_states
         hidden_states = self.layernorm_before(hidden_states)
+
         hidden_states = hidden_states.view(batch_size, height, width, channels)
+
         # pad hidden_states to multiples of window size
         hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
@@ -677,14 +678,15 @@ def forward(
             hidden_states = layer_outputs[0]
 
+        hidden_states_before_downsampling = hidden_states
         if self.downsample is not None:
             height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
             output_dimensions = (height, width, height_downsampled, width_downsampled)
-            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
+            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
         else:
             output_dimensions = (height, width, height, width)
 
-        stage_outputs = (hidden_states, output_dimensions)
+        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
 
         if output_attentions:
             stage_outputs += layer_outputs[1:]
@@ -722,9 +724,9 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
         return_dict: Optional[bool] = True,
     ) -> Union[Tuple, DonutSwinEncoderOutput]:
-        all_input_dimensions = ()
         all_hidden_states = () if output_hidden_states else None
         all_reshaped_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -755,12 +757,22 @@ def custom_forward(*inputs):
                 layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
 
             hidden_states = layer_outputs[0]
-            output_dimensions = layer_outputs[1]
+            hidden_states_before_downsampling = layer_outputs[1]
+            output_dimensions = layer_outputs[2]
             input_dimensions = (output_dimensions[-2], output_dimensions[-1])
-            all_input_dimensions += (input_dimensions,)
 
-            if output_hidden_states:
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                # rearrange b (h w) c -> b c h w
+                # here we use the original (not downsampled) height and width
+                reshaped_hidden_state = hidden_states_before_downsampling.view(
+                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                )
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
                 batch_size, _, hidden_size = hidden_states.shape
                 # rearrange b (h w) c -> b c h w
                 reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
@@ -769,7 +781,7 @@ def custom_forward(*inputs):
                 all_reshaped_hidden_states += (reshaped_hidden_state,)
 
             if output_attentions:
-                all_self_attentions += layer_outputs[2:]
+                all_self_attentions += layer_outputs[3:]
 
         if not return_dict:
             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
diff --git a/src/transformers/models/swin/__init__.py b/src/transformers/models/swin/__init__.py
index 63809f369bc5..7f883dae388b 100644
--- a/src/transformers/models/swin/__init__.py
+++ b/src/transformers/models/swin/__init__.py
@@ -36,6 +36,7 @@
         "SwinForMaskedImageModeling",
         "SwinModel",
         "SwinPreTrainedModel",
+        "SwinBackbone",
     ]
 
 try:
@@ -63,6 +64,7 @@
     else:
         from .modeling_swin import (
             SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+            SwinBackbone,
             SwinForImageClassification,
             SwinForMaskedImageModeling,
             SwinModel,
diff --git a/src/transformers/models/swin/configuration_swin.py b/src/transformers/models/swin/configuration_swin.py
index 6f51e3186f9e..89c9d556b287 100644
--- a/src/transformers/models/swin/configuration_swin.py
+++ b/src/transformers/models/swin/configuration_swin.py
@@ -83,6 +83,9 @@ class SwinConfig(PretrainedConfig):
             The epsilon used by the layer normalization layers.
         encoder_stride (`int`, `optional`, defaults to 32):
             Factor to increase the spatial resolution by in the decoder head for masked image modeling.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). Will default to the last stage if unset.
 
     Example:
 
@@ -125,6 +128,7 @@ def __init__(
         initializer_range=0.02,
         layer_norm_eps=1e-5,
         encoder_stride=32,
+        out_features=None,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -151,6 +155,16 @@ def __init__(
         # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+        if out_features is not None:
+            if not isinstance(out_features, list):
+                raise ValueError("out_features should be a list")
+            for feature in out_features:
+                if feature not in self.stage_names:
+                    raise ValueError(
+                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
+                    )
+        self.out_features = out_features
 
 
 class SwinOnnxConfig(OnnxConfig):
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index e1e59ffc77bf..fe46e7f532c3 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -26,7 +26,8 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_utils import PreTrainedModel
+from ...modeling_outputs import BackboneOutput
+from ...modeling_utils import BackboneMixin, PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from ...utils import (
     ModelOutput,
@@ -589,7 +590,6 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
         self.shift_size = shift_size
         self.window_size = config.window_size
         self.input_resolution = input_resolution
-        self.set_shift_and_window_size(input_resolution)
         self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
         self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
         self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -651,7 +651,9 @@ def forward(
         shortcut = hidden_states
         hidden_states = self.layernorm_before(hidden_states)
+
         hidden_states = hidden_states.view(batch_size, height, width, channels)
+
         # pad hidden_states to multiples of window size
         hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
@@ -742,14 +744,15 @@ def forward(
             hidden_states = layer_outputs[0]
 
+        hidden_states_before_downsampling = hidden_states
         if self.downsample is not None:
             height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
             output_dimensions = (height, width, height_downsampled, width_downsampled)
-            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
+            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
         else:
             output_dimensions = (height, width, height, width)
 
-        stage_outputs = (hidden_states, output_dimensions)
+        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
 
         if output_attentions:
             stage_outputs += layer_outputs[1:]
@@ -786,9 +789,9 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
         return_dict: Optional[bool] = True,
     ) -> Union[Tuple, SwinEncoderOutput]:
-        all_input_dimensions = ()
         all_hidden_states = () if output_hidden_states else None
         all_reshaped_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -819,12 +822,22 @@ def custom_forward(*inputs):
                 layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
 
             hidden_states = layer_outputs[0]
-            output_dimensions = layer_outputs[1]
+            hidden_states_before_downsampling = layer_outputs[1]
+            output_dimensions = layer_outputs[2]
             input_dimensions = (output_dimensions[-2], output_dimensions[-1])
-            all_input_dimensions += (input_dimensions,)
 
-            if output_hidden_states:
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                # rearrange b (h w) c -> b c h w
+                # here we use the original (not downsampled) height and width
+                reshaped_hidden_state = hidden_states_before_downsampling.view(
+                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                )
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
                 batch_size, _, hidden_size = hidden_states.shape
                 # rearrange b (h w) c -> b c h w
                 reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
@@ -833,7 +846,7 @@ def custom_forward(*inputs):
                 all_reshaped_hidden_states += (reshaped_hidden_state,)
 
             if output_attentions:
-                all_self_attentions += layer_outputs[2:]
+                all_self_attentions += layer_outputs[3:]
 
         if not return_dict:
             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
@@ -1214,3 +1227,118 @@ def forward(
             attentions=outputs.attentions,
             reshaped_hidden_states=outputs.reshaped_hidden_states,
         )
+
+
+@add_start_docstrings(
+    """
+    Swin backbone, to be used with frameworks like DETR and MaskFormer.
+    """,
+    SWIN_START_DOCSTRING,
+)
+class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
+    def __init__(self, config: SwinConfig):
+        super().__init__(config)
+
+        self.stage_names = config.stage_names
+
+        self.embeddings = SwinEmbeddings(config)
+        self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
+
+        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+
+        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
+        self.out_feature_channels = {}
+        self.out_feature_channels["stem"] = config.embed_dim
+        for i, stage in enumerate(self.stage_names[1:]):
+            self.out_feature_channels[stage] = num_features[i]
+
+        # Add layer norms to hidden states of out_features
+        hidden_states_norms = dict()
+        for stage, num_channels in zip(self.out_features, self.channels):
+            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
+        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.patch_embeddings
+
+    @property
+    def channels(self):
+        return [self.out_feature_channels[name] for name in self.out_features]
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> BackboneOutput:
+        """
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+        >>> model = AutoBackbone.from_pretrained(
+        ...     "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
+        ... )
+
+        >>> inputs = processor(image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> feature_maps = outputs.feature_maps
+        >>> list(feature_maps[-1].shape)
+        [1, 768, 7, 7]
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        embedding_output, input_dimensions = self.embeddings(pixel_values)
+
+        outputs = self.encoder(
+            embedding_output,
+            input_dimensions,
+            head_mask=None,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+            output_hidden_states_before_downsampling=True,
+            return_dict=True,
+        )
+
+        hidden_states = outputs.reshaped_hidden_states
+
+        feature_maps = ()
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                batch_size, num_channels, height, width = hidden_state.shape
+                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)
+
+        if not return_dict:
+            output = (feature_maps,)
+            if output_hidden_states:
+                output += (outputs.hidden_states,)
+            return output
+
+        return BackboneOutput(
+            feature_maps=feature_maps,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
index 2a405a1b9341..1e839e767f6f 100644
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -817,14 +817,15 @@ def forward(
             hidden_states = layer_outputs[0]
 
+        hidden_states_before_downsampling = hidden_states
         if self.downsample is not None:
             height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
             output_dimensions = (height, width, height_downsampled, width_downsampled)
-            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
+            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
         else:
             output_dimensions = (height, width, height, width)
 
-        stage_outputs = (hidden_states, output_dimensions)
+        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
 
         if output_attentions:
             stage_outputs += layer_outputs[1:]
@@ -865,9 +866,9 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
         return_dict: Optional[bool] = True,
     ) -> Union[Tuple, Swinv2EncoderOutput]:
-        all_input_dimensions = ()
         all_hidden_states = () if output_hidden_states else None
         all_reshaped_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -898,12 +899,22 @@ def custom_forward(*inputs):
                 layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
 
             hidden_states = layer_outputs[0]
-            output_dimensions = layer_outputs[1]
+            hidden_states_before_downsampling = layer_outputs[1]
+            output_dimensions = layer_outputs[2]
             input_dimensions = (output_dimensions[-2], output_dimensions[-1])
-            all_input_dimensions += (input_dimensions,)
 
-            if output_hidden_states:
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                # rearrange b (h w) c -> b c h w
+                # here we use the original (not downsampled) height and width
+                reshaped_hidden_state = hidden_states_before_downsampling.view(
+                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                )
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
                 batch_size, _, hidden_size = hidden_states.shape
                 # rearrange b (h w) c -> b c h w
                 reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
@@ -912,7 +923,7 @@ def custom_forward(*inputs):
                 all_reshaped_hidden_states += (reshaped_hidden_state,)
 
             if output_attentions:
-                all_self_attentions += layer_outputs[2:]
+                all_self_attentions += layer_outputs[3:]
 
         if not return_dict:
             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index e8fcfa496932..7ca088f5517a 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -5243,6 +5243,13 @@ def __init__(self, *args, **kwargs):
 SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None
 
 
+class SwinBackbone(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class SwinForImageClassification(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py
index 9a5541d50911..3d182392cdb0 100644
--- a/tests/models/swin/test_modeling_swin.py
+++ b/tests/models/swin/test_modeling_swin.py
@@ -30,7 +30,7 @@
     import torch
     from torch import nn
 
-    from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
+    from transformers import SwinBackbone, SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
     from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST
 
 if is_vision_available():
@@ -66,6 +66,7 @@ def __init__(
         use_labels=True,
         type_sequence_label_size=10,
         encoder_stride=8,
+        out_features=["stage1", "stage2"],
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -91,6 +92,7 @@ def __init__(
         self.use_labels = use_labels
         self.type_sequence_label_size = type_sequence_label_size
         self.encoder_stride = encoder_stride
+        self.out_features = out_features
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -123,6 +125,7 @@ def get_config(self):
             layer_norm_eps=self.layer_norm_eps,
             initializer_range=self.initializer_range,
             encoder_stride=self.encoder_stride,
+            out_features=self.out_features,
         )
 
     def create_and_check_model(self, config, pixel_values, labels):
@@ -136,6 +139,33 @@ def create_and_check_model(self, config, pixel_values, labels):
 
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
 
+    def create_and_check_backbone(self, config, pixel_values, labels):
+        model = SwinBackbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify hidden states
+        self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
+        self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])
+
+        # verify channels
+        self.parent.assertEqual(len(model.channels), len(config.out_features))
+
+        # verify backbone works with out_features=None
+        config.out_features = None
+        model = SwinBackbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify feature maps
+        self.parent.assertEqual(len(result.feature_maps), 1)
+        self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])
+
+        # verify channels
+        self.parent.assertEqual(len(model.channels), 1)
+
     def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
         model = SwinForMaskedImageModeling(config=config)
         model.to(torch_device)
@@ -190,6 +220,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             SwinModel,
+            SwinBackbone,
             SwinForImageClassification,
             SwinForMaskedImageModeling,
         )
@@ -222,6 +253,10 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_backbone(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_backbone(*config_and_inputs)
+
     def test_for_masked_image_modeling(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
@@ -230,8 +265,12 @@ def test_for_image_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
 
+    @unittest.skip(reason="Swin does not use inputs_embeds")
     def test_inputs_embeds(self):
-        # Swin does not use inputs_embeds
+        pass
+
+    @unittest.skip(reason="Swin Transformer does not use feedforward chunking")
+    def test_feed_forward_chunking(self):
         pass
 
     def test_model_common_attributes(self):
@@ -299,11 +338,8 @@ def test_attention_outputs(self):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
 
-            if hasattr(self.model_tester, "num_hidden_states_types"):
-                added_hidden_states = self.model_tester.num_hidden_states_types
-            else:
-                # also another +1 for reshaped_hidden_states
-                added_hidden_states = 2
+            # also another +1 for reshaped_hidden_states
+            added_hidden_states = 1 if model_class.__name__ == "SwinBackbone" else 2
             self.assertEqual(out_len + added_hidden_states, len(outputs))
 
             self_attentions = outputs.attentions
@@ -344,17 +380,18 @@ def check_hidden_states_output(self, inputs_dict, config, model_class, image_siz
             [num_patches, self.model_tester.embed_dim],
         )
 
-        reshaped_hidden_states = outputs.reshaped_hidden_states
-        self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
+        if not model_class.__name__ == "SwinBackbone":
+            reshaped_hidden_states = outputs.reshaped_hidden_states
+            self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
 
-        batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
-        reshaped_hidden_states = (
-            reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
-        )
-        self.assertListEqual(
-            list(reshaped_hidden_states.shape[-2:]),
-            [num_patches, self.model_tester.embed_dim],
-        )
+            batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
+            reshaped_hidden_states = (
+                reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
+            )
+            self.assertListEqual(
+                list(reshaped_hidden_states.shape[-2:]),
+                [num_patches, self.model_tester.embed_dim],
+            )
 
     def test_hidden_states_output(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/utils/check_repo.py b/utils/check_repo.py
index c72c089d7906..a7eb713efbf8 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -681,6 +681,7 @@ def find_all_documented_objects():
     "NatBackbone",
     "MaskFormerSwinConfig",
    "MaskFormerSwinModel",
+    "SwinBackbone",
 ]
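
For reference, a minimal usage sketch of the API this patch introduces; it is not part of the diff itself. It assumes a transformers build containing the `SwinBackbone` class and `out_features` config argument added above, uses a randomly initialized `SwinConfig` (default Swin-Tiny-style hyperparameters), and mirrors what `create_and_check_backbone` exercises in the test changes. The channel counts and shapes in the comments follow from those defaults and are illustrative only.

```python
# Usage sketch (not part of the patch): randomly initialized weights, shape checks only.
import torch

from transformers import SwinBackbone, SwinConfig

# out_features must be a subset of config.stage_names ("stem", "stage1", ..., "stage4").
# Passing them in stage order keeps them aligned with outputs.feature_maps below.
config = SwinConfig(out_features=["stage2", "stage3", "stage4"])
model = SwinBackbone(config)
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values)

# One feature map per requested stage; channel counts are exposed via model.channels.
for stage, num_channels, feature_map in zip(model.out_features, model.channels, outputs.feature_maps):
    print(stage, num_channels, tuple(feature_map.shape))
# stage2 192 (1, 192, 28, 28)
# stage3 384 (1, 384, 14, 14)
# stage4 768 (1, 768, 7, 7)
```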