diff --git a/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py b/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py index 0494b0d1dc88..794174931202 100644 --- a/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py @@ -161,6 +161,13 @@ def __init__( **kwargs, ) + # Extract timm.create_model kwargs; TimmBackbone doesn't forward arbitrary config attrs to timm + timm_kwargs = {} + if getattr(backbone_config, "model_type", None) == "timm_backbone": + for attr in ("img_size", "always_partition"): + if hasattr(backbone_config, attr): + timm_kwargs[attr] = getattr(backbone_config, attr) + if text_config is None: logger.info("`text_config` is `None`. Initializing the config with the default `clip_text_model`") text_config = CONFIG_MAPPING["clip_text_model"]() @@ -212,8 +219,14 @@ def __init__( self.eval_size = eval_size self.learn_initial_query = learn_initial_query self.cache_size = cache_size + self.timm_kwargs = timm_kwargs super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + def to_dict(self): + output = super().to_dict() + output.pop("timm_kwargs", None) + return output + __all__ = ["OmDetTurboConfig"] diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index a30b52897051..1819c7aa873d 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -26,7 +26,6 @@ from ... import initialization as init from ...activations import ACT2CLS, ACT2FN -from ...backbone_utils import load_backbone from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer @@ -39,7 +38,7 @@ logging, torch_compilable_check, ) -from ..auto import AutoModel +from ..auto import AutoBackbone, AutoModel from .configuration_omdet_turbo import OmDetTurboConfig @@ -279,7 +278,7 @@ class OmDetTurboVisionBackbone(nn.Module): def __init__(self, config: OmDetTurboConfig): super().__init__() self.apply_layernorm_after_vision_backbone = config.apply_layernorm_after_vision_backbone - self.vision_backbone = load_backbone(config) + self.vision_backbone = AutoBackbone.from_config(config.backbone_config, **getattr(config, "timm_kwargs", {})) self.layer_norms = nn.ModuleList( [nn.LayerNorm(in_channel_dim, eps=config.layer_norm_eps) for in_channel_dim in config.encoder_in_channels] )