Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,7 @@
_import_structure["models.dinat"].extend(
[
"DINAT_PRETRAINED_MODEL_ARCHIVE_LIST",
"DinatBackbone",
"DinatForImageClassification",
"DinatModel",
"DinatPreTrainedModel",
Expand Down Expand Up @@ -1767,6 +1768,7 @@
_import_structure["models.nat"].extend(
[
"NAT_PRETRAINED_MODEL_ARCHIVE_LIST",
"NatBackbone",
"NatForImageClassification",
"NatModel",
"NatPreTrainedModel",
Expand Down Expand Up @@ -4385,6 +4387,7 @@
)
from .models.dinat import (
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
DinatBackbone,
DinatForImageClassification,
DinatModel,
DinatPreTrainedModel,
Expand Down Expand Up @@ -4781,6 +4784,7 @@
)
from .models.nat import (
NAT_PRETRAINED_MODEL_ARCHIVE_LIST,
NatBackbone,
NatForImageClassification,
NatModel,
NatPreTrainedModel,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,9 @@
[
# Backbone mapping
("bit", "BitBackbone"),
("dinat", "DinatBackbone"),
("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"),
("resnet", "ResNetBackbone"),
]
)
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/dinat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"DinatForImageClassification",
"DinatModel",
"DinatPreTrainedModel",
"DinatBackbone",
]

if TYPE_CHECKING:
Expand All @@ -48,6 +49,7 @@
else:
from .modeling_dinat import (
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
DinatBackbone,
DinatForImageClassification,
DinatModel,
DinatPreTrainedModel,
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/models/dinat/configuration_dinat.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class DinatConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
The initial value for the layer scale. Disabled if <=0.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.

Example:

Expand Down Expand Up @@ -113,6 +116,7 @@ def __init__(
initializer_range=0.02,
layer_norm_eps=1e-5,
layer_scale_init_value=0.0,
out_features=None,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -138,3 +142,13 @@ def __init__(
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.layer_scale_init_value = layer_scale_init_value
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
self.out_features = out_features
139 changes: 131 additions & 8 deletions src/transformers/models/dinat/modeling_dinat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
Expand All @@ -35,6 +36,7 @@
add_start_docstrings_to_model_forward,
is_natten_available,
logging,
replace_return_docstrings,
requires_backends,
)
from .configuration_dinat import DinatConfig
Expand Down Expand Up @@ -555,14 +557,11 @@ def forward(
layer_outputs = layer_module(hidden_states, output_attentions)
hidden_states = layer_outputs[0]

hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(layer_outputs[0])
else:
output_dimensions = (height, width, height, width)
hidden_states = self.downsample(hidden_states_before_downsampling)

stage_outputs = (hidden_states, output_dimensions)
stage_outputs = (hidden_states, hidden_states_before_downsampling)

if output_attentions:
stage_outputs += layer_outputs[1:]
Expand Down Expand Up @@ -596,6 +595,7 @@ def forward(
hidden_states: torch.Tensor,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, DinatEncoderOutput]:
all_hidden_states = () if output_hidden_states else None
Expand All @@ -612,8 +612,14 @@ def forward(
layer_outputs = layer_module(hidden_states, output_attentions)

hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]

if output_hidden_states:
if output_hidden_states and output_hidden_states_before_downsampling:
# rearrange b h w c -> b c h w
reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
# rearrange b h w c -> b c h w
reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,)
Expand Down Expand Up @@ -871,3 +877,120 @@ def forward(
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)


@add_start_docstrings(
"NAT backbone, to be used with frameworks like DETR and MaskFormer.",
DINAT_START_DOCSTRING,
)
class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)

requires_backends(self, ["natten"])

self.stage_names = config.stage_names

self.embeddings = DinatEmbeddings(config)
self.encoder = DinatEncoder(config)

self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]

num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.out_feature_channels = {}
self.out_feature_channels["stem"] = config.embed_dim
for i, stage in enumerate(self.stage_names[1:]):
self.out_feature_channels[stage] = num_features[i]

# Add layer norms to hidden states of out_features
hidden_states_norms = dict()
for stage, num_channels in zip(self.out_features, self.channels):
hidden_states_norms[stage] = nn.LayerNorm(num_channels)
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.patch_embeddings

@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]

@add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
"""
Returns:

Examples:

```python
>>> from transformers import AutoImageProcessor, AutoBackbone
>>> import torch
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
>>> model = AutoBackbone.from_pretrained(
... "shi-labs/nat-mini-in1k-2240", out_features=["stage1", "stage2", "stage3", "stage4"]
... )

>>> inputs = processor(image, return_tensors="pt")

>>> outputs = model(**inputs)

>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
[1, 2048, 7, 7]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

embedding_output = self.embeddings(pixel_values)

outputs = self.encoder(
embedding_output,
output_attentions=output_attentions,
output_hidden_states=True,
output_hidden_states_before_downsampling=True,
return_dict=True,
)

hidden_states = outputs.reshaped_hidden_states

feature_maps = ()
for stage, hidden_state in zip(self.stage_names, hidden_states):
if stage in self.out_features:
batch_size, num_channels, height, width = hidden_state.shape
hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
hidden_state = hidden_state.view(batch_size, height * width, num_channels)
hidden_state = self.hidden_states_norms[stage](hidden_state)
hidden_state = hidden_state.view(batch_size, height, width, num_channels)
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
feature_maps += (hidden_state,)

if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs.hidden_states,)
return output

return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
2 changes: 2 additions & 0 deletions src/transformers/models/nat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"NatForImageClassification",
"NatModel",
"NatPreTrainedModel",
"NatBackbone",
]

if TYPE_CHECKING:
Expand All @@ -48,6 +49,7 @@
else:
from .modeling_nat import (
NAT_PRETRAINED_MODEL_ARCHIVE_LIST,
NatBackbone,
NatForImageClassification,
NatModel,
NatPreTrainedModel,
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/models/nat/configuration_nat.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class NatConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
The initial value for the layer scale. Disabled if <=0.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.

Example:

Expand Down Expand Up @@ -110,6 +113,7 @@ def __init__(
initializer_range=0.02,
layer_norm_eps=1e-5,
layer_scale_init_value=0.0,
out_features=None,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -134,3 +138,13 @@ def __init__(
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.layer_scale_init_value = layer_scale_init_value
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
self.out_features = out_features
Loading