Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,7 @@
"MODEL_WITH_LM_HEAD_MAPPING",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
"AutoModel",
"AutoBackbone",
"AutoModelForAudioClassification",
"AutoModelForAudioFrameClassification",
"AutoModelForAudioXVector",
Expand Down Expand Up @@ -1877,6 +1878,7 @@
"ResNetForImageClassification",
"ResNetModel",
"ResNetPreTrainedModel",
"ResNetBackbone",
]
)
_import_structure["models.retribert"].extend(
Expand Down Expand Up @@ -3946,6 +3948,7 @@
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoBackbone,
AutoModel,
AutoModelForAudioClassification,
AutoModelForAudioFrameClassification,
Expand Down Expand Up @@ -4730,6 +4733,7 @@
)
from .models.resnet import (
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ResNetBackbone,
ResNetForImageClassification,
ResNetModel,
ResNetPreTrainedModel,
Expand Down
13 changes: 13 additions & 0 deletions src/transformers/modeling_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,3 +1263,16 @@ class XVectorOutput(ModelOutput):
embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class BackboneOutput(ModelOutput):
"""
Base class for outputs of backbones.

Args:
feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`):
Feature maps of the stages.
"""

feature_maps: Tuple[torch.FloatTensor] = None
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"MODEL_WITH_LM_HEAD_MAPPING",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
"AutoModel",
"AutoBackbone",
"AutoModelForAudioClassification",
"AutoModelForAudioFrameClassification",
"AutoModelForAudioXVector",
Expand Down Expand Up @@ -225,6 +226,7 @@
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoBackbone,
AutoModel,
AutoModelForAudioClassification,
AutoModelForAudioFrameClassification,
Expand Down
13 changes: 13 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,13 @@
]
)

MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
[
# Backbone mapping
("resnet", "ResNetBackbone"),
]
)

MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
Expand Down Expand Up @@ -903,6 +910,8 @@
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)

MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)


class AutoModel(_BaseAutoModelClass):
_model_mapping = MODEL_MAPPING
Expand Down Expand Up @@ -1126,6 +1135,10 @@ class AutoModelForAudioXVector(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


class AutoBackbone(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_BACKBONE_MAPPING


AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")


Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/resnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"ResNetForImageClassification",
"ResNetModel",
"ResNetPreTrainedModel",
"ResNetBackbone",
]

try:
Expand Down Expand Up @@ -63,6 +64,7 @@
else:
from .modeling_resnet import (
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ResNetBackbone,
ResNetForImageClassification,
ResNetModel,
ResNetPreTrainedModel,
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/models/resnet/configuration_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class ResNetConfig(PretrainedConfig):
are supported.
downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
If `True`, the first stage will downsample the inputs using a `stride` of 2.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`,
`"stage3"`, `"stage4"`.

Example:
```python
Expand Down Expand Up @@ -85,6 +88,7 @@ def __init__(
layer_type="bottleneck",
hidden_act="relu",
downsample_in_first_stage=False,
out_features=None,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -97,6 +101,16 @@ def __init__(
self.layer_type = layer_type
self.hidden_act = hidden_act
self.downsample_in_first_stage = downsample_in_first_stage
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
self.out_features = out_features


class ResNetOnnxConfig(OnnxConfig):
Expand Down
75 changes: 74 additions & 1 deletion src/transformers/models/resnet/modeling_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,19 @@

from ...activations import ACT2FN
from ...modeling_outputs import (
BackboneOutput,
BaseModelOutputWithNoAttention,
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_resnet import ResNetConfig


Expand Down Expand Up @@ -416,3 +423,69 @@ def forward(
return (loss,) + output if loss is not None else output

return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)


@add_start_docstrings(
"""
ResNet backbone, to be used with frameworks like DETR and MaskFormer.
""",
RESNET_START_DOCSTRING,
)
class ResNetBackbone(ResNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.stage_names = config.stage_names
self.resnet = ResNetModel(config)

self.out_features = config.out_features

self.out_feature_channels = {
"stem": config.embedding_size,
"stage1": config.hidden_sizes[0],
"stage2": config.hidden_sizes[1],
"stage3": config.hidden_sizes[2],
"stage4": config.hidden_sizes[3],
}

# initialize weights and apply final processing
self.post_init()

@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]

@add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
def forward(self, pixel_values: Optional[torch.FloatTensor] = None) -> BackboneOutput:
"""
Returns:

Examples:

```python
>>> from transformers import AutoImageProcessor, AutoBackbone
>>> import torch
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
>>> model = AutoBackbone.from_pretrained("microsoft/resnet-50")

>>> inputs = processor(image, return_tensors="pt")

>>> outputs = model(**inputs)
```"""
outputs = self.resnet(pixel_values, output_hidden_states=True, return_dict=True)

hidden_states = outputs.hidden_states

feature_maps = ()
for idx, stage in enumerate(self.stage_names):
if stage in self.out_features:
feature_maps += (hidden_states[idx],)

return BackboneOutput(feature_maps=feature_maps)
14 changes: 14 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,13 @@ def load_tf_weights_in_albert(*args, **kwargs):
MODEL_WITH_LM_HEAD_MAPPING = None


class AutoBackbone(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class AutoModel(metaclass=DummyObject):
_backends = ["torch"]

Expand Down Expand Up @@ -4523,6 +4530,13 @@ def load_tf_weights_in_rembert(*args, **kwargs):
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = None


class ResNetBackbone(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class ResNetForImageClassification(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
22 changes: 21 additions & 1 deletion tests/models/resnet/test_modeling_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import torch
from torch import nn

from transformers import ResNetForImageClassification, ResNetModel
from transformers import ResNetBackbone, ResNetForImageClassification, ResNetModel
from transformers.models.resnet.modeling_resnet import RESNET_PRETRAINED_MODEL_ARCHIVE_LIST


Expand All @@ -55,6 +55,7 @@ def __init__(
hidden_act="relu",
num_labels=3,
scope=None,
out_features=["stage1", "stage2", "stage3", "stage4"],
):
self.parent = parent
self.batch_size = batch_size
Expand All @@ -69,6 +70,7 @@ def __init__(
self.num_labels = num_labels
self.scope = scope
self.num_stages = len(hidden_sizes)
self.out_features = out_features

def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
Expand All @@ -89,6 +91,7 @@ def get_config(self):
depths=self.depths,
hidden_act=self.hidden_act,
num_labels=self.num_labels,
out_features=self.out_features,
)

def create_and_check_model(self, config, pixel_values, labels):
Expand All @@ -110,6 +113,19 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def create_and_check_backbone(self, config, pixel_values, labels):
model = ResNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)

# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [3, 10, 8, 8])

# verify channels
self.parent.assertListEqual(model.channels, config.hidden_sizes)

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
Expand Down Expand Up @@ -176,6 +192,10 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)

def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)

def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

Expand Down
3 changes: 3 additions & 0 deletions utils/check_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"ResNetBackbone", # Backbones have their own tests.
"CLIPSegDecoder", # Building part of bigger (tested) model.
"TableTransformerEncoder", # Building part of bigger (tested) model.
"TableTransformerDecoder", # Building part of bigger (tested) model.
Expand Down Expand Up @@ -668,6 +669,8 @@ def find_all_documented_objects():
"PyTorchBenchmarkArguments",
"TensorFlowBenchmark",
"TensorFlowBenchmarkArguments",
"ResNetBackbone",
"AutoBackbone",
]


Expand Down