Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/transformers/models/omdet_turbo/configuration_omdet_turbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ def __init__(
**kwargs,
)

# Extract timm.create_model kwargs; TimmBackbone doesn't forward arbitrary config attrs to timm
timm_kwargs = {}
if getattr(backbone_config, "model_type", None) == "timm_backbone":
for attr in ("img_size", "always_partition"):
if hasattr(backbone_config, attr):
timm_kwargs[attr] = getattr(backbone_config, attr)

Comment on lines +165 to +170
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add a comment on why this is needed pls

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added :)

if text_config is None:
logger.info("`text_config` is `None`. Initializing the config with the default `clip_text_model`")
text_config = CONFIG_MAPPING["clip_text_model"]()
Expand Down Expand Up @@ -212,8 +219,14 @@ def __init__(
self.eval_size = eval_size
self.learn_initial_query = learn_initial_query
self.cache_size = cache_size
self.timm_kwargs = timm_kwargs

super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

def to_dict(self):
    """Serialize this configuration to a dictionary.

    The runtime-only `timm_kwargs` entry (extra arguments forwarded to
    `timm.create_model`) is removed so it is never persisted with the config.
    """
    serialized = super().to_dict()
    if "timm_kwargs" in serialized:
        del serialized["timm_kwargs"]
    return serialized


__all__ = ["OmDetTurboConfig"]
5 changes: 2 additions & 3 deletions src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from ... import initialization as init
from ...activations import ACT2CLS, ACT2FN
from ...backbone_utils import load_backbone
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_bidirectional_mask
from ...modeling_layers import GradientCheckpointingLayer
Expand All @@ -39,7 +38,7 @@
logging,
torch_compilable_check,
)
from ..auto import AutoModel
from ..auto import AutoBackbone, AutoModel
from .configuration_omdet_turbo import OmDetTurboConfig


Expand Down Expand Up @@ -279,7 +278,7 @@ class OmDetTurboVisionBackbone(nn.Module):
def __init__(self, config: OmDetTurboConfig):
    """Vision backbone wrapper for OmDet-Turbo.

    Instantiates the backbone described by `config.backbone_config` and one
    `LayerNorm` per backbone output level.

    Args:
        config: Full model configuration. `backbone_config` selects the
            backbone; `timm_kwargs` (if present) carries extra
            `timm.create_model` arguments for timm backbones.
    """
    super().__init__()
    # Flag read elsewhere to decide whether backbone features are normalized
    # before downstream use — TODO confirm against this module's forward.
    self.apply_layernorm_after_vision_backbone = config.apply_layernorm_after_vision_backbone
    # `getattr(..., {})` keeps backward compatibility with older configs that
    # were created before `timm_kwargs` existed as an attribute.
    self.vision_backbone = AutoBackbone.from_config(config.backbone_config, **getattr(config, "timm_kwargs", {}))
    # One LayerNorm per feature level; channel widths come from the config's
    # `encoder_in_channels`, which must match the backbone's output channels.
    self.layer_norms = nn.ModuleList(
        [nn.LayerNorm(in_channel_dim, eps=config.layer_norm_eps) for in_channel_dim in config.encoder_in_channels]
    )
Expand Down
Loading