Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/models/rt_detr/configuration_rt_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class RTDetrConfig(PretrainedConfig):
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
freeze_backbone_batch_norm (`bool`, *optional*, defaults to `True`):
Whether to freeze the batch normalization layers in the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
Expand Down Expand Up @@ -190,6 +192,7 @@ def __init__(
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
freeze_backbone_batch_norms=True,
backbone_kwargs=None,
# encoder HybridEncoder
encoder_hidden_dim=256,
Expand Down Expand Up @@ -280,6 +283,7 @@ def __init__(
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
self.backbone_kwargs = backbone_kwargs
# encoder
self.encoder_hidden_dim = encoder_hidden_dim
Expand Down
7 changes: 4 additions & 3 deletions src/transformers/models/rt_detr/modeling_rt_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,10 @@ def __init__(self, config):

backbone = load_backbone(config)

# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
if config.freeze_backbone_batch_norms:
# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
self.model = backbone
self.intermediate_channel_sizes = self.model.channels

Expand Down