Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
"hf_xet",
"huggingface-hub>=0.30.0,<1.0",
"huggingface-hub>=0.31.4,<1.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
Expand Down
50 changes: 44 additions & 6 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,19 @@ def __getattribute__(self, key):
key = super().__getattribute__("attribute_map")[key]
return super().__getattribute__(key)

def __post_init__(self):
    # Dataclass-based configs (decorated with `@strict`) do not run the custom
    # `__init__` below, so shared defaults are applied after dataclass field
    # initialization instead.
    self._set_defaults()

def __init__(self, **kwargs):
    # Legacy (non-dataclass) configs still construct through `__init__`; it
    # funnels into the same defaulting logic, forwarding all kwargs so that
    # explicitly-passed values take precedence over defaults.
    self._set_defaults(**kwargs)

def _set_defaults(self, **kwargs):
def _get_default_if_unset(attribute_name, default_value):
set_attribute = getattr(self, attribute_name, None)
if set_attribute is not None:
return set_attribute
return kwargs.pop(attribute_name, default_value)

# Attributes with defaults
self.return_dict = kwargs.pop("return_dict", True)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
Expand Down Expand Up @@ -267,12 +279,11 @@ def __init__(self, **kwargs):
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
self.prefix = kwargs.pop("prefix", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.sep_token_id = kwargs.pop("sep_token_id", None)

self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
self.bos_token_id = _get_default_if_unset("bos_token_id", None)
self.pad_token_id = _get_default_if_unset("pad_token_id", None)
self.eos_token_id = _get_default_if_unset("eos_token_id", None)
self.sep_token_id = _get_default_if_unset("sep_token_id", None)
self.decoder_start_token_id = _get_default_if_unset("decoder_start_token_id", None)

# task specific arguments
self.task_specific_params = kwargs.pop("task_specific_params", None)
Expand Down Expand Up @@ -382,6 +393,29 @@ def _attn_implementation(self):
def _attn_implementation(self, value):
self._attn_implementation_internal = value

@property
def attn_implementation(self):
    # Public alias for the internal `_attn_implementation` property.
    return self._attn_implementation

@attn_implementation.setter
def attn_implementation(self, value):
    # Delegates to the `_attn_implementation` setter, which stores the value
    # in `_attn_implementation_internal`.
    self._attn_implementation = value

def validate_token_ids(self):
    """Part of `@strict`-powered validation. Validates the contents of the special tokens.

    Emits a warning (not an exception) when a special token id falls outside
    `[0, vocab_size)` — several configs on the Hub store invalid ids (e.g.
    `pad_token_id=-1`), so raising would break loading them.
    """
    text_config = self.get_text_config()
    vocab_size = getattr(text_config, "vocab_size", None)
    if vocab_size is not None:
        for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]:
            token_id = getattr(text_config, token_name, None)
            # NOTE(review): assumes token ids are ints — a list-valued
            # `eos_token_id` would raise TypeError on the chained comparison
            # below; confirm no config stores a list here.
            if token_id is not None and not 0 <= token_id < vocab_size:
                # Can't be an exception until we can load configs that fail validation: several configs on the Hub
                # store invalid special tokens, e.g. `pad_token_id=-1`
                logger.warning_once(
                    f"Model config: {token_name} must be `None` or an integer within the vocabulary (between 0 "
                    f"and {vocab_size - 1}), got {token_id}. This may result in unexpected behavior."
                )

def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
Expand Down Expand Up @@ -414,6 +448,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
UserWarning,
)

# Strict validation at save-time: prevent bad patterns from propagating
if hasattr(self, "validate"):
self.validate()

os.makedirs(save_directory, exist_ok=True)

if push_to_hub:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"hf_xet": "hf_xet",
"huggingface-hub": "huggingface-hub>=0.30.0,<1.0",
"huggingface-hub": "huggingface-hub>=0.31.4,<1.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2025,6 +2025,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
"`PretrainedConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
# class-level validation of config (as opposed to the attribute-level validation provided by `@strict`)
if hasattr(config, "validate"):
config.validate()
Comment on lines +2028 to +2030
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validates the config at model init time


Comment on lines +2028 to +2031
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# class-level validation of config (as opposed to the attribute-level validation provided by `@strict`)
if hasattr(config, "validate"):
config.validate()

not needed anymore

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible to load a config, modify it to an invalid state, and then use that config to instantiate a model :( as such, the model should ensure the config is valid before committing resources

if not getattr(config, "_attn_implementation_autoset", False):
# config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests
dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype()
Expand Down
84 changes: 38 additions & 46 deletions src/transformers/models/albert/configuration_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@
"""ALBERT model configuration"""

from collections import OrderedDict
from typing import Mapping
from dataclasses import dataclass
from typing import Literal, Mapping, Optional, Union

from huggingface_hub.dataclasses import strict

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...validators import activation_fn_key, interval, probability


@strict(accept_kwargs=True)
@dataclass
class AlbertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`AlbertModel`] or a [`TFAlbertModel`]. It is used
Expand Down Expand Up @@ -53,9 +59,9 @@ class AlbertConfig(PretrainedConfig):
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu_new"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0):
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0):
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
max_position_embeddings (`int`, *optional*, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
Expand Down Expand Up @@ -103,51 +109,37 @@ class AlbertConfig(PretrainedConfig):
>>> configuration = model.config
```"""

vocab_size: int = interval(min=1)(default=30000)
embedding_size: int = interval(min=1)(default=128)
hidden_size: int = interval(min=1)(default=4096)
num_hidden_layers: int = interval(min=1)(default=12)
num_hidden_groups: int = interval(min=1)(default=1)
num_attention_heads: int = interval(min=0)(default=64)
intermediate_size: int = interval(min=1)(default=16384)
inner_group_num: int = interval(min=0)(default=1)
hidden_act: str = activation_fn_key(default="gelu_new")
Comment on lines +112 to +120
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW I think we should only use default here. let's be less verbose when not needed! Not even sure we need to specify this being an interval!

hidden_dropout_prob: Union[float, int] = probability(default=0.0)
attention_probs_dropout_prob: Union[float, int] = probability(default=0.0)
max_position_embeddings: int = interval(min=0)(default=512)
type_vocab_size: int = interval(min=1)(default=2)
initializer_range: float = interval(min=0.0)(default=0.02)
layer_norm_eps: float = interval(min=0.0)(default=1e-12)
classifier_dropout_prob: float = probability(default=0.1)
position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute"
pad_token_id: Optional[int] = 0
bos_token_id: Optional[int] = 2
eos_token_id: Optional[int] = 3

# Not part of __init__
model_type = "albert"

def __init__(
self,
vocab_size=30000,
embedding_size=128,
hidden_size=4096,
num_hidden_layers=12,
num_hidden_groups=1,
num_attention_heads=64,
intermediate_size=16384,
inner_group_num=1,
hidden_act="gelu_new",
hidden_dropout_prob=0,
attention_probs_dropout_prob=0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
classifier_dropout_prob=0.1,
position_embedding_type="absolute",
pad_token_id=0,
bos_token_id=2,
eos_token_id=3,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_hidden_groups = num_hidden_groups
self.num_attention_heads = num_attention_heads
self.inner_group_num = inner_group_num
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.classifier_dropout_prob = classifier_dropout_prob
self.position_embedding_type = position_embedding_type
def validate_architecture(self):
    """Part of `@strict`-powered validation. Checks cross-field architecture constraints.

    Raises:
        ValueError: if the hidden size cannot be evenly split across attention heads.
    """
    width = self.hidden_size
    heads = self.num_attention_heads
    if width % heads != 0:
        raise ValueError(
            f"The hidden size ({width}) is not a multiple of the number of attention "
            f"heads ({heads})."
        )


# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert
Expand Down
6 changes: 0 additions & 6 deletions src/transformers/models/albert/modeling_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,6 @@ def forward(
class AlbertAttention(nn.Module):
def __init__(self, config: AlbertConfig):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads}"
)

self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.attention_head_size = config.hidden_size // config.num_attention_heads
Expand Down
106 changes: 106 additions & 0 deletions src/transformers/validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# coding=utf-8
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Validators to be used with `huggingface_hub.dataclasses.validated_field`. We recommend using the validator(s) that best
describe the constraints of your dataclass fields, for the best user experience (e.g. better error messages).
"""

from typing import Callable, Optional, Union

from huggingface_hub.dataclasses import as_validated_field

from .utils import is_torch_available


if is_torch_available():
from .activations import ACT2FN
else:
ACT2FN = {}


# Numerical validators


def interval(
    min: Optional[Union[int, float]] = None,
    max: Optional[Union[int, float]] = None,
    exclude_min: bool = False,
    exclude_max: bool = False,
) -> Callable:
    """
    Parameterized validator that ensures that `value` is within the defined interval. Optionally, the interval can be
    open on either side. Expected usage: `interval(min=0)(default=8)`

    Args:
        min (`int` or `float`, *optional*):
            Minimum value of the interval. Unbounded below when `None`.
        max (`int` or `float`, *optional*):
            Maximum value of the interval. Unbounded above when `None`.
        exclude_min (`bool`, *optional*, defaults to `False`):
            If True, the minimum value is excluded from the interval (strict inequality).
        exclude_max (`bool`, *optional*, defaults to `False`):
            If True, the maximum value is excluded from the interval (strict inequality).
    """
    # Build the error message once at validator-creation time; only the
    # offending value is interpolated at validation time.
    error_message = "Value must be"
    if min is not None:
        if exclude_min:
            error_message += f" greater than {min}"
        else:
            error_message += f" greater or equal to {min}"
    if min is not None and max is not None:
        error_message += " and"
    if max is not None:
        if exclude_max:
            error_message += f" smaller than {max}"
        else:
            error_message += f" smaller or equal to {max}"
    error_message += ", got {value}."

    # Bug fix: the previous `min = min or float("-inf")` treated 0 (and 0.0) as
    # "unset" because 0 is falsy, so e.g. `interval(min=0)` silently accepted
    # negative values and `interval(max=0)` accepted any value. Compare against
    # None explicitly instead.
    if min is None:
        min = float("-inf")
    if max is None:
        max = float("inf")

    @as_validated_field
    def _inner(value: Union[int, float]):
        min_valid = min <= value if not exclude_min else min < value
        max_valid = value <= max if not exclude_max else value < max
        if not (min_valid and max_valid):
            raise ValueError(error_message.format(value=value))

    return _inner


@as_validated_field
def probability(value: float):
    """Ensures that `value` is a valid probability number, i.e. [0,1]."""
    # `not (a and b)` keeps the original chained-comparison semantics
    # (including rejection of NaN, which fails both comparisons).
    within_unit_interval = 0 <= value and value <= 1
    if not within_unit_interval:
        raise ValueError(f"Value must be a probability between 0.0 and 1.0, got {value}.")


# String validators


@as_validated_field
def activation_fn_key(value: str):
    """Ensures that `value` is a string corresponding to an activation function."""
    # TODO (joao): in python 3.11+, we can build a Literal type from the keys of ACT2FN
    # Validation is skipped entirely when ACT2FN could not be imported (empty mapping).
    if ACT2FN and value not in ACT2FN:
        raise ValueError(
            f"Value must be one of {list(ACT2FN.keys())}, got {value}. "
            "Make sure to use a string that corresponds to an activation function."
        )


__all__ = ["interval", "probability", "activation_fn_key"]
2 changes: 1 addition & 1 deletion tests/models/albert/test_modeling_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

def setUp(self):
self.model_tester = AlbertModelTester(self)
self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37)
self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=64)

def test_config(self):
self.config_tester.run_common_tests()
Expand Down
17 changes: 11 additions & 6 deletions tests/test_configuration_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,28 @@ def create_and_test_config_common_properties(self):
self.parent.assertTrue(hasattr(config, prop), msg=f"`{prop}` does not exist")

# Test that config has the common properties as setter
for idx, name in enumerate(common_properties):
dummy_value = 64
for name in common_properties:
try:
setattr(config, name, idx)
setattr(config, name, dummy_value)
self.parent.assertEqual(
getattr(config, name), idx, msg=f"`{name} value {idx} expected, but was {getattr(config, name)}"
getattr(config, name),
dummy_value,
msg=f"`{name} value {dummy_value} expected, but was {getattr(config, name)}",
)
except NotImplementedError:
# Some models might not be able to implement setters for common_properties
# In that case, a NotImplementedError is raised
pass

# Test if config class can be called with Config(prop_name=..)
for idx, name in enumerate(common_properties):
for name in common_properties:
try:
config = self.config_class(**{name: idx})
config = self.config_class(**{name: dummy_value})
self.parent.assertEqual(
getattr(config, name), idx, msg=f"`{name} value {idx} expected, but was {getattr(config, name)}"
getattr(config, name),
dummy_value,
msg=f"`{name} value {dummy_value} expected, but was {getattr(config, name)}",
)
except NotImplementedError:
# Some models might not be able to implement setters for common_properties
Expand Down
Loading