@@ -71,6 +71,8 @@ class ViTHybridConfig(PretrainedConfig):
Whether to add a bias to the queries, keys and values.
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*, defaults to `None`):
The configuration of the backbone in a dictionary or the config object of the backbone.
backbone_featmap_shape (`List[int]`, *optional*, defaults to `[1, 1024, 24, 24]`):
Used only for the `hybrid` embedding type. The shape of the feature maps of the backbone.

Example:

@@ -103,6 +105,7 @@ def __init__(
image_size=224,
patch_size=1,
num_channels=3,
backbone_featmap_shape=[1, 1024, 24, 24],
qkv_bias=True,
**kwargs
):
@@ -128,6 +131,7 @@ def __init__(
backbone_config_class = BitConfig
backbone_config = backbone_config_class(**backbone_config)

self.backbone_featmap_shape = backbone_featmap_shape
self.backbone_config = backbone_config
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
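For context only — a minimal sketch, not part of the diff, assuming the public ViTHybridConfig constructor — the new field is simply stored on the config and can be overridden at construction time:

from transformers import ViTHybridConfig

# Illustrative values: [1, 1024, 24, 24] is the documented default
# (batch, channels, height, width) shape of the backbone feature map.
config = ViTHybridConfig(backbone_featmap_shape=[1, 1024, 24, 24])
print(config.backbone_featmap_shape)  # [1, 1024, 24, 24]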
9 changes: 4 additions & 5 deletions src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
@@ -166,11 +166,10 @@ def __init__(self, config, feature_size=None):
feature_dim = self.backbone.channels[-1]

if feature_size is None:
dummy_image = torch.zeros(1, num_channels, image_size[0], image_size[1])
with torch.no_grad():
feature_map = self.backbone(dummy_image).feature_maps[-1]
Comment on lines -170 to -171

Contributor Author: You predicted correctly the failing test in #20649 (review) @sgugger !

Collaborator: Ah ah, making my told you so dance right now ;-)

Contributor Author: 🤣 🤣

feature_size = feature_map.shape[-2:]
feature_dim = feature_map.shape[1]
feature_map = config.backbone_featmap_shape

feature_size = feature_map[-2:]
feature_dim = feature_map[1]
else:
feature_size = (
feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
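As an aside (an illustration, not part of the diff), the replacement branch reduces to plain indexing into the configured shape, assuming it follows (batch, channels, height, width):

# Hypothetical walk-through with the documented default shape.
feature_map = [1, 1024, 24, 24]   # config.backbone_featmap_shape
feature_size = feature_map[-2:]   # [24, 24] -> spatial size of the backbone output
feature_dim = feature_map[1]      # 1024     -> channel dim fed to the patch projection

Reading the shape from the config avoids building a dummy image and running a backbone forward pass at construction time, which is what the removed lines did.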
32 changes: 31 additions & 1 deletion tests/models/vit_hybrid/test_modeling_vit_hybrid.py
@@ -19,7 +19,7 @@
import unittest

from transformers import ViTHybridConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import require_accelerate, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
@@ -57,6 +57,7 @@ def __init__(
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
backbone_featmap_shape=[1, 16, 4, 4],
scope=None,
):
self.parent = parent
@@ -76,6 +77,7 @@ def __init__(
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.backbone_featmap_shape = backbone_featmap_shape

# in ViT hybrid, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
# the number of patches is based on the feature map of the backbone, which by default uses an output stride
@@ -95,6 +97,16 @@ def prepare_config_and_inputs(self):
return config, pixel_values, labels

def get_config(self):
backbone_config = {
"global_padding": "same",
"layer_type": "bottleneck",
"depths": [3, 4, 9],
"out_features": ["stage1", "stage2", "stage3"],
"embedding_dynamic_padding": True,
"hidden_sizes": [4, 8, 16, 32],
"num_groups": 2,
}

return ViTHybridConfig(
image_size=self.image_size,
patch_size=self.patch_size,
@@ -108,6 +120,8 @@ def get_config(self):
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
backbone_featmap_shape=self.backbone_featmap_shape,
backbone_config=backbone_config,
)

def create_and_check_model(self, config, pixel_values, labels):
@@ -229,3 +243,19 @@ def test_inference_image_classification_head(self):
expected_slice = torch.tensor([-1.9090, -0.4993, -0.2389]).to(torch_device)

self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

@slow
@require_accelerate
def test_accelerate_inference(self):
feature_extractor = ViTHybridImageProcessor.from_pretrained("google/vit-hybrid-base-bit-384")
model = ViTHybridForImageClassification.from_pretrained("google/vit-hybrid-base-bit-384", device_map="auto")

image = prepare_img()

inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()

self.assertEqual(model.config.id2label[predicted_class_idx], "tabby, tabby cat")
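
Outside the test suite, the same pattern can serve as a quick sanity check (a sketch under assumptions: accelerate and a vision backend are installed, the checkpoint name is unchanged, and the COCO cats image commonly used in the transformers docs is reachable):

from PIL import Image
import requests
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification

processor = ViTHybridImageProcessor.from_pretrained("google/vit-hybrid-base-bit-384")
model = ViTHybridForImageClassification.from_pretrained("google/vit-hybrid-base-bit-384", device_map="auto")

# Two cats on a couch; with device_map="auto", accelerate dispatches inputs for us.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")
predicted_class_idx = model(**inputs).logits.argmax(-1).item()
print(model.config.id2label[predicted_class_idx])  # expected: "tabby, tabby cat"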