From d7ff91acf37f39756297debb40b687aa62cc3223 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 15 Nov 2022 11:49:43 +0100 Subject: [PATCH 01/12] Add ResNetBackbone --- src/transformers/__init__.py | 4 + src/transformers/modeling_outputs.py | 22 +++++ src/transformers/models/auto/__init__.py | 2 + src/transformers/models/auto/modeling_auto.py | 13 +++ src/transformers/models/resnet/__init__.py | 2 + .../models/resnet/configuration_resnet.py | 5 ++ .../models/resnet/modeling_resnet.py | 88 ++++++++++++++++++- src/transformers/models/resnet/test.py | 19 ++++ src/transformers/utils/dummy_pt_objects.py | 14 +++ 9 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 src/transformers/models/resnet/test.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 34845dfbda08..8079ea61e302 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -920,6 +920,7 @@ "MODEL_WITH_LM_HEAD_MAPPING", "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", "AutoModel", + "AutoBackbone", "AutoModelForAudioClassification", "AutoModelForAudioFrameClassification", "AutoModelForAudioXVector", @@ -1877,6 +1878,7 @@ "ResNetForImageClassification", "ResNetModel", "ResNetPreTrainedModel", + "ResNetBackbone", ] ) _import_structure["models.retribert"].extend( @@ -3946,6 +3948,7 @@ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, + AutoBackbone, AutoModel, AutoModelForAudioClassification, AutoModelForAudioFrameClassification, @@ -4730,6 +4733,7 @@ ) from .models.resnet import ( RESNET_PRETRAINED_MODEL_ARCHIVE_LIST, + ResNetBackbone, ResNetForImageClassification, ResNetModel, ResNetPreTrainedModel, diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index 606476e96456..10451b8dc960 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -1263,3 +1263,25 @@ class XVectorOutput(ModelOutput): embeddings: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class BackboneOutput(ModelOutput): + """ + Base class for outputs of backbones. + + Args: + stage_names (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Names of the stages. + hidden_states (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): + Hidden states (also called feature maps) of the stages. + channels (`Tuple(int)`): + Number of channels of the stages. + strides (`Tuple(int)`): + Strides of the stages. + """ + + stage_names: Tuple[str] = None + hidden_states: Tuple[torch.FloatTensor] = None + channels: Tuple[int] = None + strides: Tuple[int] = None diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index a1a1356eb006..718f4a221454 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -73,6 +73,7 @@ "MODEL_WITH_LM_HEAD_MAPPING", "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", "AutoModel", + "AutoBackbone", "AutoModelForAudioClassification", "AutoModelForAudioFrameClassification", "AutoModelForAudioXVector", @@ -225,6 +226,7 @@ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, + AutoBackbone, AutoModel, AutoModelForAudioClassification, AutoModelForAudioFrameClassification, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e9c6d822cad3..15afb9660b98 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -836,6 +836,13 @@ ] ) +BACKBONE_MAPPING_NAMES = OrderedDict( + [ + # Backbone mapping + ("resnet", "ResNetBackbone"), + ] +) + MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) @@ -903,6 +910,8 @@ ) MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES) +BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, BACKBONE_MAPPING_NAMES) + class AutoModel(_BaseAutoModelClass): _model_mapping = MODEL_MAPPING @@ -1126,6 +1135,10 @@ class AutoModelForAudioXVector(_BaseAutoModelClass): _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING +class AutoBackbone(_BaseAutoModelClass): + _model_mapping = BACKBONE_MAPPING + + AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector") diff --git a/src/transformers/models/resnet/__init__.py b/src/transformers/models/resnet/__init__.py index f62c2999671d..be110e9c50ab 100644 --- a/src/transformers/models/resnet/__init__.py +++ b/src/transformers/models/resnet/__init__.py @@ -36,6 +36,7 @@ "ResNetForImageClassification", "ResNetModel", "ResNetPreTrainedModel", + "ResNetBackbone", ] try: @@ -63,6 +64,7 @@ else: from .modeling_resnet import ( RESNET_PRETRAINED_MODEL_ARCHIVE_LIST, + ResNetBackbone, ResNetForImageClassification, ResNetModel, ResNetPreTrainedModel, diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py index 8a863e6ae278..c7fffe647af4 100644 --- a/src/transformers/models/resnet/configuration_resnet.py +++ b/src/transformers/models/resnet/configuration_resnet.py @@ -58,6 +58,9 @@ class ResNetConfig(PretrainedConfig): are supported. downsample_in_first_stage (`bool`, *optional*, defaults to `False`): If `True`, the first stage will downsample the inputs using a `stride` of 2. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of "stem", "stage1", "stage2", "stage3", + "stage4", Example: ```python @@ -85,6 +88,7 @@ def __init__( layer_type="bottleneck", hidden_act="relu", downsample_in_first_stage=False, + out_features=None, **kwargs ): super().__init__(**kwargs) @@ -97,6 +101,7 @@ def __init__( self.layer_type = layer_type self.hidden_act = hidden_act self.downsample_in_first_stage = downsample_in_first_stage + self.out_features = out_features class ResNetOnnxConfig(OnnxConfig): diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index b771aa3e3125..ff3677d369e7 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -23,12 +23,19 @@ from ...activations import ACT2FN from ...modeling_outputs import ( + BackboneOutput, BaseModelOutputWithNoAttention, BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention, ) from ...modeling_utils import PreTrainedModel -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_resnet import ResNetConfig @@ -416,3 +423,82 @@ def forward( return (loss,) + output if loss is not None else output return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +@add_start_docstrings( + """ + ResNet backbone, to be used with frameworks like DETR and MaskFormer. + """, + RESNET_START_DOCSTRING, +) +class ResNetBackbone(ResNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.resnet = ResNetModel(config) + + # note that this assumes there are always 4 stages + self.stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"] + self.out_features = config.out_features + + # TODO calculate strides appropriately + current_stride = self.resnet.embedder.embedder.convolution.stride[0] + self.out_feature_strides = { + "stem": current_stride, + "stage1": current_stride * 2, + "stage2": current_stride * 2, + "stage3": current_stride * 2, + "stage4": current_stride * 2, + } + + self.out_feature_channels = { + "stem": config.embedding_size, + "stage1": config.hidden_sizes[0], + "stage2": config.hidden_sizes[1], + "stage3": config.hidden_sizes[2], + "stage4": config.hidden_sizes[3], + } + + # initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward(self, pixel_values: Optional[torch.FloatTensor] = None) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + >>> model = AutoBackbone.from_pretrained("microsoft/resnet-50") + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + ```""" + outputs = self.resnet(pixel_values, output_hidden_states=True, return_dict=True) + + hidden_states = outputs.hidden_states + + feature_maps = [] + channels = [] + strides = [] + for idx, stage in enumerate(self.stage_names): + if stage in self.out_features: + feature_maps.append(hidden_states[idx]) + channels.append(self.out_feature_channels[stage]) + strides.append(self.out_feature_strides[stage]) + + return BackboneOutput( + stage_names=self.out_features, hidden_states=feature_maps, channels=channels, strides=strides + ) diff --git a/src/transformers/models/resnet/test.py b/src/transformers/models/resnet/test.py new file mode 100644 index 000000000000..0574c88c4c9b --- /dev/null +++ b/src/transformers/models/resnet/test.py @@ -0,0 +1,19 @@ +from transformers import AutoImageProcessor, AutoBackbone +import torch +from PIL import Image +import requests + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = Image.open(requests.get(url, stream=True).raw) + +processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") +model = AutoBackbone.from_pretrained("microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"]) + +inputs = processor(image, return_tensors="pt") + +outputs = model(**inputs) + +print(outputs.keys()) + +for k,v in zip(outputs.stage_names, outputs.hidden_states): + print(k, v.shape) \ No newline at end of file diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index cc93aa76cbcd..c55c7d6995af 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -437,6 +437,13 @@ def load_tf_weights_in_albert(*args, **kwargs): MODEL_WITH_LM_HEAD_MAPPING = None +class AutoBackbone(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class AutoModel(metaclass=DummyObject): _backends = ["torch"] @@ -4523,6 +4530,13 @@ def load_tf_weights_in_rembert(*args, **kwargs): RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = None +class ResNetBackbone(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ResNetForImageClassification(metaclass=DummyObject): _backends = ["torch"] From 24c7fd2c51ef0521bc25de0929415a053730ebe4 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 15 Nov 2022 12:04:25 +0100 Subject: [PATCH 02/12] Define channels and strides as property --- src/transformers/modeling_outputs.py | 6 ------ .../models/resnet/modeling_resnet.py | 16 +++++++++------- src/transformers/models/resnet/test.py | 13 ++++++++----- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index 10451b8dc960..aafa8036f05c 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -1275,13 +1275,7 @@ class BackboneOutput(ModelOutput): Names of the stages. hidden_states (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): Hidden states (also called feature maps) of the stages. - channels (`Tuple(int)`): - Number of channels of the stages. - strides (`Tuple(int)`): - Strides of the stages. """ stage_names: Tuple[str] = None hidden_states: Tuple[torch.FloatTensor] = None - channels: Tuple[int] = None - strides: Tuple[int] = None diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index ff3677d369e7..2367b457ca4d 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -462,6 +462,14 @@ def __init__(self, config): # initialize weights and apply final processing self.post_init() + @property + def channels(self): + return [self.out_feature_channels[name] for name in self.out_features] + + @property + def strides(self): + return [self.out_feature_strides[name] for name in self.out_features] + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) def forward(self, pixel_values: Optional[torch.FloatTensor] = None) -> BackboneOutput: @@ -491,14 +499,8 @@ def forward(self, pixel_values: Optional[torch.FloatTensor] = None) -> BackboneO hidden_states = outputs.hidden_states feature_maps = [] - channels = [] - strides = [] for idx, stage in enumerate(self.stage_names): if stage in self.out_features: feature_maps.append(hidden_states[idx]) - channels.append(self.out_feature_channels[stage]) - strides.append(self.out_feature_strides[stage]) - return BackboneOutput( - stage_names=self.out_features, hidden_states=feature_maps, channels=channels, strides=strides - ) + return BackboneOutput(stage_names=self.out_features, hidden_states=feature_maps) diff --git a/src/transformers/models/resnet/test.py b/src/transformers/models/resnet/test.py index 0574c88c4c9b..c8177abfc36e 100644 --- a/src/transformers/models/resnet/test.py +++ b/src/transformers/models/resnet/test.py @@ -1,7 +1,8 @@ -from transformers import AutoImageProcessor, AutoBackbone -import torch from PIL import Image + import requests +from transformers import AutoBackbone, AutoImageProcessor + url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) @@ -13,7 +14,9 @@ outputs = model(**inputs) -print(outputs.keys()) +print(model.channels) + +print(model.strides) -for k,v in zip(outputs.stage_names, outputs.hidden_states): - print(k, v.shape) \ No newline at end of file +for k, v in zip(outputs.stage_names, outputs.hidden_states): + print(k, v.shape) From b6eaa63ddd63b28c76e8bbaf8d98313be7095996 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 15 Nov 2022 12:24:28 +0100 Subject: [PATCH 03/12] Remove file --- src/transformers/modeling_outputs.py | 2 +- src/transformers/models/resnet/test.py | 22 ---------------------- 2 files changed, 1 insertion(+), 23 deletions(-) delete mode 100644 src/transformers/models/resnet/test.py diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index aafa8036f05c..d84426854f8f 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -1271,7 +1271,7 @@ class BackboneOutput(ModelOutput): Base class for outputs of backbones. Args: - stage_names (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + stage_names (`tuple(str)`): Names of the stages. hidden_states (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): Hidden states (also called feature maps) of the stages. diff --git a/src/transformers/models/resnet/test.py b/src/transformers/models/resnet/test.py deleted file mode 100644 index c8177abfc36e..000000000000 --- a/src/transformers/models/resnet/test.py +++ /dev/null @@ -1,22 +0,0 @@ -from PIL import Image - -import requests -from transformers import AutoBackbone, AutoImageProcessor - - -url = "http://images.cocodataset.org/val2017/000000039769.jpg" -image = Image.open(requests.get(url, stream=True).raw) - -processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") -model = AutoBackbone.from_pretrained("microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"]) - -inputs = processor(image, return_tensors="pt") - -outputs = model(**inputs) - -print(model.channels) - -print(model.strides) - -for k, v in zip(outputs.stage_names, outputs.hidden_states): - print(k, v.shape) From 8be9b7c75c095e73fd236771f0c9f71de8ee63a2 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 15 Nov 2022 14:38:49 +0100 Subject: [PATCH 04/12] Add test for backbone --- tests/models/resnet/test_modeling_resnet.py | 23 ++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py index 557883e0b1ba..9d68f50de348 100644 --- a/tests/models/resnet/test_modeling_resnet.py +++ b/tests/models/resnet/test_modeling_resnet.py @@ -30,7 +30,7 @@ import torch from torch import nn - from transformers import ResNetForImageClassification, ResNetModel + from transformers import ResNetBackbone, ResNetForImageClassification, ResNetModel from transformers.models.resnet.modeling_resnet import RESNET_PRETRAINED_MODEL_ARCHIVE_LIST @@ -55,6 +55,7 @@ def __init__( hidden_act="relu", num_labels=3, scope=None, + out_features=["stage1", "stage2", "stage3", "stage4"], ): self.parent = parent self.batch_size = batch_size @@ -69,6 +70,7 @@ def __init__( self.num_labels = num_labels self.scope = scope self.num_stages = len(hidden_sizes) + self.out_features = out_features def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -89,6 +91,7 @@ def get_config(self): depths=self.depths, hidden_act=self.hidden_act, num_labels=self.num_labels, + out_features=self.out_features, ) def create_and_check_model(self, config, pixel_values, labels): @@ -110,6 +113,20 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels result = model(pixel_values, labels=labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + def create_and_check_backbone(self, config, pixel_values, labels): + model = ResNetBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify hidden states + self.parent.assertEqual(len(result.hidden_states), 4) + self.parent.assertListEqual(result.stage_names, config.out_features) + self.parent.assertListEqual(list(result.hidden_states[0].shape), [3, 10, 8, 8]) + + # verify channels + self.parent.assertListEqual(model.channels, config.hidden_sizes) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, pixel_values, labels = config_and_inputs @@ -176,6 +193,10 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_backbone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_backbone(*config_and_inputs) + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From f47b76b2ed526b4f6255c646e1b43f13a3f98a56 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 15 Nov 2022 16:11:55 +0100 Subject: [PATCH 05/12] Update BackboneOutput class --- src/transformers/modeling_outputs.py | 9 +++------ src/transformers/models/resnet/modeling_resnet.py | 2 +- tests/models/resnet/test_modeling_resnet.py | 5 ++--- utils/check_repo.py | 1 + 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index d84426854f8f..e3ff32930e20 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -1271,11 +1271,8 @@ class BackboneOutput(ModelOutput): Base class for outputs of backbones. Args: - stage_names (`tuple(str)`): - Names of the stages. - hidden_states (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): - Hidden states (also called feature maps) of the stages. + feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): + Feature maps of the stages. """ - stage_names: Tuple[str] = None - hidden_states: Tuple[torch.FloatTensor] = None + feature_maps: Tuple[torch.FloatTensor] = None diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index 2367b457ca4d..4e3fa53acc20 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -503,4 +503,4 @@ def forward(self, pixel_values: Optional[torch.FloatTensor] = None) -> BackboneO if stage in self.out_features: feature_maps.append(hidden_states[idx]) - return BackboneOutput(stage_names=self.out_features, hidden_states=feature_maps) + return BackboneOutput(feature_maps=feature_maps) diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py index 9d68f50de348..f2e504843890 100644 --- a/tests/models/resnet/test_modeling_resnet.py +++ b/tests/models/resnet/test_modeling_resnet.py @@ -120,9 +120,8 @@ def create_and_check_backbone(self, config, pixel_values, labels): result = model(pixel_values) # verify hidden states - self.parent.assertEqual(len(result.hidden_states), 4) - self.parent.assertListEqual(result.stage_names, config.out_features) - self.parent.assertListEqual(list(result.hidden_states[0].shape), [3, 10, 8, 8]) + self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [3, 10, 8, 8]) # verify channels self.parent.assertListEqual(model.channels, config.hidden_sizes) diff --git a/utils/check_repo.py b/utils/check_repo.py index 951da58eee5d..5d99946ab9f3 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -47,6 +47,7 @@ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested + "ResNetBackbone", # Backbones have their own tests. "CLIPSegDecoder", # Building part of bigger (tested) model. "TableTransformerEncoder", # Building part of bigger (tested) model. "TableTransformerDecoder", # Building part of bigger (tested) model. From cac9d519ce6a172f7f4060ed0b826e4b91e35001 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 15 Nov 2022 16:57:10 +0100 Subject: [PATCH 06/12] Remove strides property --- src/transformers/models/resnet/modeling_resnet.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index 4e3fa53acc20..334089d3462d 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -441,16 +441,6 @@ def __init__(self, config): self.stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"] self.out_features = config.out_features - # TODO calculate strides appropriately - current_stride = self.resnet.embedder.embedder.convolution.stride[0] - self.out_feature_strides = { - "stem": current_stride, - "stage1": current_stride * 2, - "stage2": current_stride * 2, - "stage3": current_stride * 2, - "stage4": current_stride * 2, - } - self.out_feature_channels = { "stem": config.embedding_size, "stage1": config.hidden_sizes[0], @@ -466,10 +456,6 @@ def __init__(self, config): def channels(self): return [self.out_feature_channels[name] for name in self.out_features] - @property - def strides(self): - return [self.out_feature_strides[name] for name in self.out_features] - @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) def forward(self, pixel_values: Optional[torch.FloatTensor] = None) -> BackboneOutput: From 816141e6f1b678f35365bcfaa405a25ab082597b Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 15 Nov 2022 17:28:05 +0100 Subject: [PATCH 07/12] Fix docstring --- src/transformers/models/resnet/configuration_resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py index c7fffe647af4..592a53dac13d 100644 --- a/src/transformers/models/resnet/configuration_resnet.py +++ b/src/transformers/models/resnet/configuration_resnet.py @@ -59,8 +59,8 @@ class ResNetConfig(PretrainedConfig): downsample_in_first_stage (`bool`, *optional*, defaults to `False`): If `True`, the first stage will downsample the inputs using a `stride` of 2. out_features (`List[str]`, *optional*): - If used as backbone, list of features to output. Can be any of "stem", "stage1", "stage2", "stage3", - "stage4", + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, + `"stage3"`, `"stage4"`. Example: ```python From 5796f2bb4f680a305a22bbf4d634ba8ba145c143 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 15 Nov 2022 21:00:51 +0100 Subject: [PATCH 08/12] Add backbones to SHOULD_HAVE_THEIR_OWN_PAGE --- utils/check_repo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/utils/check_repo.py b/utils/check_repo.py index 5d99946ab9f3..e55ad770bd4b 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -669,6 +669,8 @@ def find_all_documented_objects(): "PyTorchBenchmarkArguments", "TensorFlowBenchmark", "TensorFlowBenchmarkArguments", + "ResNetBackbone", + "AutoBackbone", ] From b37334e6fdfc2afa32fcb0278af7c39fa1d2ab84 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 15 Nov 2022 21:06:57 +0100 Subject: [PATCH 09/12] Fix auto mapping name --- src/transformers/models/auto/modeling_auto.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 15afb9660b98..11c841407bdb 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -836,7 +836,7 @@ ] ) -BACKBONE_MAPPING_NAMES = OrderedDict( +MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( [ # Backbone mapping ("resnet", "ResNetBackbone"), @@ -910,7 +910,7 @@ ) MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES) -BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, BACKBONE_MAPPING_NAMES) +MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES) class AutoModel(_BaseAutoModelClass): @@ -1136,7 +1136,7 @@ class AutoModelForAudioXVector(_BaseAutoModelClass): class AutoBackbone(_BaseAutoModelClass): - _model_mapping = BACKBONE_MAPPING + _model_mapping = MODEL_FOR_BACKBONE_MAPPING AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector") From 9f037a59bf3f09a65150cb1be8a4ae5f7e2f9e93 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 16 Nov 2022 15:29:55 +0100 Subject: [PATCH 10/12] Add sanity check for out_features --- src/transformers/models/resnet/configuration_resnet.py | 10 ++++++++++ src/transformers/models/resnet/modeling_resnet.py | 3 +-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py index 592a53dac13d..7064a166db5f 100644 --- a/src/transformers/models/resnet/configuration_resnet.py +++ b/src/transformers/models/resnet/configuration_resnet.py @@ -78,6 +78,8 @@ class ResNetConfig(PretrainedConfig): """ model_type = "resnet" layer_types = ["basic", "bottleneck"] + # note that this assumes there are always 4 stages + stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"] def __init__( self, @@ -101,6 +103,14 @@ def __init__( self.layer_type = layer_type self.hidden_act = hidden_act self.downsample_in_first_stage = downsample_in_first_stage + if out_features is not None: + if not isinstance(out_features, list): + raise ValueError("out_features should be a list") + for feature in out_features: + if feature not in self.stage_names: + raise ValueError( + f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}" + ) self.out_features = out_features diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index 334089d3462d..b9992271efd2 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -435,10 +435,9 @@ class ResNetBackbone(ResNetPreTrainedModel): def __init__(self, config): super().__init__(config) + self.stage_names = config.stage_names self.resnet = ResNetModel(config) - # note that this assumes there are always 4 stages - self.stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"] self.out_features = config.out_features self.out_feature_channels = { From 160c86119510fa2bddd2deabec8f42199818b5c1 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 16 Nov 2022 17:26:03 +0100 Subject: [PATCH 11/12] Set stage names based on depths --- src/transformers/models/resnet/configuration_resnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py index 7064a166db5f..2d0dbc3b0fdb 100644 --- a/src/transformers/models/resnet/configuration_resnet.py +++ b/src/transformers/models/resnet/configuration_resnet.py @@ -78,8 +78,6 @@ class ResNetConfig(PretrainedConfig): """ model_type = "resnet" layer_types = ["basic", "bottleneck"] - # note that this assumes there are always 4 stages - stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"] def __init__( self, @@ -103,6 +101,7 @@ def __init__( self.layer_type = layer_type self.hidden_act = hidden_act self.downsample_in_first_stage = downsample_in_first_stage + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] if out_features is not None: if not isinstance(out_features, list): raise ValueError("out_features should be a list") From 2a2033b925f4b7da360a79f08d9db0498c113093 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Thu, 17 Nov 2022 15:02:56 +0100 Subject: [PATCH 12/12] Update to tuple --- src/transformers/models/resnet/modeling_resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index b9992271efd2..0988e478dd2a 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -483,9 +483,9 @@ def forward(self, pixel_values: Optional[torch.FloatTensor] = None) -> BackboneO hidden_states = outputs.hidden_states - feature_maps = [] + feature_maps = () for idx, stage in enumerate(self.stage_names): if stage in self.out_features: - feature_maps.append(hidden_states[idx]) + feature_maps += (hidden_states[idx],) return BackboneOutput(feature_maps=feature_maps)