Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/models/bit/configuration_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class BitConfig(PretrainedConfig):
The width factor for the model.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
- (depending on how many stages the model has).
+ (depending on how many stages the model has). Defaults to the last stage in case of `None`.

Example:
```python
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/bit/modeling_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ def __init__(self, config):
self.stage_names = config.stage_names
self.bit = BitModel(config)

- self.out_features = config.out_features
+ self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]

out_feature_channels = {}
out_feature_channels["stem"] = config.embedding_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class MaskFormerSwinConfig(PretrainedConfig):
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
out_features (`List[str]`, *optional*):
- If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`.
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). Defaults to the last stage in case of `None`.

Example:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def __init__(self, config: MaskFormerSwinConfig):
self.stage_names = config.stage_names
self.model = MaskFormerSwinModel(config)

- self.out_features = config.out_features
+ self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
if "stem" in self.out_features:
raise ValueError("This backbone does not support 'stem' in the `out_features`.")

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/resnet/configuration_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ class ResNetConfig(PretrainedConfig):
downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
If `True`, the first stage will downsample the inputs using a `stride` of 2.
out_features (`List[str]`, *optional*):
- If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`,
- `"stage3"`, `"stage4"`.
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). Defaults to the last stage in case of `None`.

Example:
```python
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/resnet/modeling_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def _init_weights(self, module):
nn.init.constant_(module.bias, 0)

def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, (ResNetModel, ResNetBackbone)):
+ if isinstance(module, ResNetEncoder):
module.gradient_checkpointing = value


Expand Down Expand Up @@ -439,7 +439,7 @@ def __init__(self, config):
self.embedder = ResNetEmbeddings(config)
self.encoder = ResNetEncoder(config)

- self.out_features = config.out_features
+ self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]

out_feature_channels = {}
out_feature_channels["stem"] = config.embedding_size
Expand Down
17 changes: 16 additions & 1 deletion tests/models/bit/test_modeling_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,29 @@ def create_and_check_backbone(self, config, pixel_values, labels):
model.eval()
result = model(pixel_values)

- # verify hidden states
+ # verify feature maps
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])

# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])

+ # verify backbone works with out_features=None
+ config.out_features = None
+ model = BitBackbone(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ # verify feature maps
+ self.parent.assertEqual(len(result.feature_maps), 1)
+ self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
+
+ # verify channels
+ self.parent.assertEqual(len(model.channels), 1)
+ self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
Expand Down
17 changes: 16 additions & 1 deletion tests/models/resnet/test_modeling_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,29 @@ def create_and_check_backbone(self, config, pixel_values, labels):
model.eval()
result = model(pixel_values)

- # verify hidden states
+ # verify feature maps
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])

# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])

+ # verify backbone works with out_features=None
+ config.out_features = None
+ model = ResNetBackbone(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ # verify feature maps
+ self.parent.assertEqual(len(result.feature_maps), 1)
+ self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
+
+ # verify channels
+ self.parent.assertEqual(len(model.channels), 1)
+ self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
Expand Down