Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Configuration base class and utilities."""

import copy
import inspect
import json
import math
import os
Expand Down Expand Up @@ -74,6 +75,7 @@
# copied from huggingface_hub.dataclasses.strict when `accept_kwargs=True`
def wrap_init_to_accept_kwargs(cls: dataclass):
original_init = cls.__init__
original_signature = inspect.signature(original_init)

@wraps(original_init)
def __init__(self, *args, **kwargs: Any) -> None:
Expand Down Expand Up @@ -107,6 +109,8 @@ def __init__(self, *args, **kwargs: Any) -> None:

self.__post_init__(**additional_kwargs)

# Preserve the original signature for type checkers
__init__.__signature__ = original_signature
cls.__init__ = __init__
return cls

Expand Down
35 changes: 35 additions & 0 deletions tests/utils/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,38 @@ def __init__(
self.assertIsInstance(new_config_instance.inf_positive, float)
self.assertIsInstance(new_config_instance.inf_negative, float)
self.assertIsInstance(new_config_instance.nan, float)

def test_pretrained_config_signature_preserved(self):
"""Test that PreTrainedConfig subclass __init__ signatures are preserved for type checkers.

Regression test for issue #45071: v5.4.0 breaks PretrainedConfig type checking.

When wrap_init_to_accept_kwargs wraps the dataclass-generated __init__, it should
preserve the original signature so type checkers like mypy can properly recognize
config parameters.
"""
import inspect

# Test with BertConfig which has multiple parameters
sig = inspect.signature(BertConfig.__init__)
params = list(sig.parameters.keys())

# Verify that common config parameters are in the signature
self.assertIn("self", params)
self.assertIn("vocab_size", params)
self.assertIn("hidden_size", params)
self.assertIn("num_hidden_layers", params)

# Verify that the signature is not just (*args, **kwargs)
# by checking that we have more than just 'self' and 'kwargs'
self.assertGreater(len(params), 2, "Signature should have concrete parameters, not just *args/**kwargs")

# Test that we can instantiate with the documented parameters
config = BertConfig(vocab_size=30522, hidden_size=768, num_hidden_layers=12)
self.assertEqual(config.vocab_size, 30522)
self.assertEqual(config.hidden_size, 768)
self.assertEqual(config.num_hidden_layers, 12)

# Test that we can still pass extra kwargs for backward compatibility
config_with_extra = BertConfig(vocab_size=30522, extra_kwarg_for_bc="should_be_handled")
self.assertEqual(config_with_extra.vocab_size, 30522)
Loading