diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index fd96aacabc2a..265b79dca13a 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import collections
 import gc
+import inspect
 import json
 import os
 import re
@@ -932,6 +933,15 @@ def floating_point_ops(
         return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
 
 
+class BackboneMixin:
+    def forward_with_filtered_kwargs(self, *args, **kwargs):
+
+        signature = dict(inspect.signature(self.forward).parameters)
+        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
+
+        return self(*args, **filtered_kwargs)
+
+
 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     r"""
     Base class for all models.
diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py
index ac88a704b4c4..3704b1d9872b 100644
--- a/src/transformers/models/bit/modeling_bit.py
+++ b/src/transformers/models/bit/modeling_bit.py
@@ -30,7 +30,7 @@
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import BackboneMixin, PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -842,7 +842,7 @@ def forward(
     """,
     BIT_START_DOCSTRING,
 )
-class BitBackbone(BitPreTrainedModel):
+class BitBackbone(BitPreTrainedModel, BackboneMixin):
     def __init__(self, config):
         super().__init__(config)
 
diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py
index 60410a36210d..05af1b90ff1b 100644
--- a/src/transformers/models/maskformer/modeling_maskformer_swin.py
+++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -27,7 +27,7 @@
 from ...activations import ACT2FN
 from ...file_utils import ModelOutput
 from ...modeling_outputs import BackboneOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import BackboneMixin, PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from .configuration_maskformer_swin import MaskFormerSwinConfig
 
@@ -837,7 +837,7 @@ def forward(
 )
 
 
-class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel):
+class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
     """
     MaskFormerSwin backbone, designed especially for the MaskFormer framework.
 
diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py
index c3d65ddc05e6..ebd134be5454 100644
--- a/src/transformers/models/resnet/modeling_resnet.py
+++ b/src/transformers/models/resnet/modeling_resnet.py
@@ -28,7 +28,7 @@
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import BackboneMixin, PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -431,7 +431,7 @@ def forward(
     """,
     RESNET_START_DOCSTRING,
 )
-class ResNetBackbone(ResNetPreTrainedModel):
+class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
     def __init__(self, config):
         super().__init__(config)
 