diff --git a/src/transformers/models/bit/configuration_bit.py b/src/transformers/models/bit/configuration_bit.py index 6418549ab876..7c1e105107e3 100644 --- a/src/transformers/models/bit/configuration_bit.py +++ b/src/transformers/models/bit/configuration_bit.py @@ -63,7 +63,7 @@ class BitConfig(PretrainedConfig): The width factor for the model. out_features (`List[str]`, *optional*): If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. - (depending on how many stages the model has). + (depending on how many stages the model has). Will default to the last stage if unset. Example: ```python diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index df85add18c22..d40e7f4e8c02 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -851,7 +851,7 @@ def __init__(self, config): self.stage_names = config.stage_names self.bit = BitModel(config) - self.out_features = config.out_features + self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] out_feature_channels = {} out_feature_channels["stem"] = config.embedding_size diff --git a/src/transformers/models/maskformer/configuration_maskformer_swin.py b/src/transformers/models/maskformer/configuration_maskformer_swin.py index 4c9f1a4ca4df..36e0746552c8 100644 --- a/src/transformers/models/maskformer/configuration_maskformer_swin.py +++ b/src/transformers/models/maskformer/configuration_maskformer_swin.py @@ -69,7 +69,8 @@ class MaskFormerSwinConfig(PretrainedConfig): layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. out_features (`List[str]`, *optional*): - If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`. + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). Will default to the last stage if unset. Example: diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 60410a36210d..09bdd817dafd 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -855,7 +855,7 @@ def __init__(self, config: MaskFormerSwinConfig): self.stage_names = config.stage_names self.model = MaskFormerSwinModel(config) - self.out_features = config.out_features + self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] if "stem" in self.out_features: raise ValueError("This backbone does not support 'stem' in the `out_features`.") diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py index 2d0dbc3b0fdb..74f6c6939722 100644 --- a/src/transformers/models/resnet/configuration_resnet.py +++ b/src/transformers/models/resnet/configuration_resnet.py @@ -59,8 +59,8 @@ class ResNetConfig(PretrainedConfig): downsample_in_first_stage (`bool`, *optional*, defaults to `False`): If `True`, the first stage will downsample the inputs using a `stride` of 2. out_features (`List[str]`, *optional*): - If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, - `"stage3"`, `"stage4"`. + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). Will default to the last stage if unset. Example: ```python diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index c3d65ddc05e6..9efedd1faa9c 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -267,7 +267,7 @@ def _init_weights(self, module): nn.init.constant_(module.bias, 0) def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (ResNetModel, ResNetBackbone)): + if isinstance(module, ResNetEncoder): module.gradient_checkpointing = value @@ -439,7 +439,7 @@ def __init__(self, config): self.embedder = ResNetEmbeddings(config) self.encoder = ResNetEncoder(config) - self.out_features = config.out_features + self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] out_feature_channels = {} out_feature_channels["stem"] = config.embedding_size diff --git a/tests/models/bit/test_modeling_bit.py b/tests/models/bit/test_modeling_bit.py index 0c3bf147c890..7b7e07cb8fb6 100644 --- a/tests/models/bit/test_modeling_bit.py +++ b/tests/models/bit/test_modeling_bit.py @@ -119,7 +119,7 @@ def create_and_check_backbone(self, config, pixel_values, labels): model.eval() result = model(pixel_values) - # verify hidden states + # verify feature maps self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4]) @@ -127,6 +127,21 @@ def create_and_check_backbone(self, config, pixel_values, labels): self.parent.assertEqual(len(model.channels), len(config.out_features)) self.parent.assertListEqual(model.channels, config.hidden_sizes[1:]) + # verify backbone works with out_features=None + config.out_features = None + model = BitBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), 1) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1]) + + # verify channels + self.parent.assertEqual(len(model.channels), 1) + self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]]) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, pixel_values, labels = config_and_inputs diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py index 53777d27c84d..15d3dca3c53f 100644 --- a/tests/models/resnet/test_modeling_resnet.py +++ b/tests/models/resnet/test_modeling_resnet.py @@ -119,7 +119,7 @@ def create_and_check_backbone(self, config, pixel_values, labels): model.eval() result = model(pixel_values) - # verify hidden states + # verify feature maps self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4]) @@ -127,6 +127,21 @@ def create_and_check_backbone(self, config, pixel_values, labels): self.parent.assertEqual(len(model.channels), len(config.out_features)) self.parent.assertListEqual(model.channels, config.hidden_sizes[1:]) + # verify backbone works with out_features=None + config.out_features = None + model = ResNetBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), 1) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1]) + + # verify channels + self.parent.assertEqual(len(model.channels), 1) + self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]]) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, pixel_values, labels = config_and_inputs