Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
import collections
import gc
import inspect
import json
import os
import re
Expand Down Expand Up @@ -932,6 +933,15 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


class BackboneMixin:
def forward_with_filtered_kwargs(self, *args, **kwargs):

signature = dict(inspect.signature(self.forward).parameters)
filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}

return self(*args, **filtered_kwargs)


class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
r"""
Base class for all models.
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/bit/modeling_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
Expand Down Expand Up @@ -842,7 +842,7 @@ def forward(
""",
BIT_START_DOCSTRING,
)
class BitBackbone(BitPreTrainedModel):
class BitBackbone(BitPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ...activations import ACT2FN
from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from .configuration_maskformer_swin import MaskFormerSwinConfig

Expand Down Expand Up @@ -837,7 +837,7 @@ def forward(
)


class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel):
class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
"""
MaskFormerSwin backbone, designed especially for the MaskFormer framework.

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/resnet/modeling_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
Expand Down Expand Up @@ -431,7 +431,7 @@ def forward(
""",
RESNET_START_DOCSTRING,
)
class ResNetBackbone(ResNetPreTrainedModel):
class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)

Expand Down