From 4b3a9548b039cd242105a761d95abadab01c1ee9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 4 Mar 2025 13:10:47 +0000 Subject: [PATCH 01/18] mvp --- .../models/albert/configuration_albert.py | 98 ++++++++++--------- .../models/albert/modeling_albert.py | 7 +- src/transformers/validators.py | 76 ++++++++++++++ 3 files changed, 130 insertions(+), 51 deletions(-) create mode 100644 src/transformers/validators.py diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index e1e2d4547cc4..65d4d932293d 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -18,10 +18,22 @@ from collections import OrderedDict from typing import Mapping +from huggingface_hub.utils import strict_dataclass, validated_field + from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig - - +from ...validators import ( + activation_function_key, + choice_str, + positive_float, + positive_int, + probability, + strictly_positive_int, + vocabulary_token, +) + + +@strict_dataclass class AlbertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`AlbertModel`] or a [`TFAlbertModel`]. It is used @@ -103,51 +115,47 @@ class AlbertConfig(PretrainedConfig): >>> configuration = model.config ```""" + vocab_size: int = validated_field(strictly_positive_int, default=30000) + embedding_size: int = validated_field(strictly_positive_int, default=128) + hidden_size: int = validated_field(strictly_positive_int, default=4096) + num_hidden_layers: int = validated_field(strictly_positive_int, default=12) + num_hidden_groups: int = validated_field(strictly_positive_int, default=1) + num_attention_heads: int = validated_field(positive_int, default=64) + intermediate_size: int = validated_field(strictly_positive_int, default=16384) + inner_group_num: int = validated_field(positive_int, default=1) + hidden_act: str = validated_field(activation_function_key, default="gelu_new") + hidden_dropout_prob: float = validated_field(probability, default=0) + attention_probs_dropout_prob: float = validated_field(probability, default=0) + max_position_embeddings: int = validated_field(positive_int, default=512) + type_vocab_size: int = validated_field(strictly_positive_int, default=2) + initializer_range: float = validated_field(positive_float, default=0.02) + layer_norm_eps: float = validated_field(positive_float, default=1e-12) + classifier_dropout_prob: float = validated_field(probability, default=0.1) + position_embedding_type: str = validated_field( + choice_str, choices=["absolute", "relative_key", "relative_key_query"], default="absolute" + ) + pad_token_id: int = validated_field(vocabulary_token, vocab_size=vocab_size, default=0) + bos_token_id: int = validated_field(vocabulary_token, vocab_size=vocab_size, default=2) + eos_token_id: int = validated_field(vocabulary_token, vocab_size=vocab_size, default=3) + + # Not part of __init__ model_type = "albert" - def __init__( - self, - vocab_size=30000, - embedding_size=128, - hidden_size=4096, - num_hidden_layers=12, - num_hidden_groups=1, - num_attention_heads=64, - intermediate_size=16384, - inner_group_num=1, - hidden_act="gelu_new", - hidden_dropout_prob=0, - attention_probs_dropout_prob=0, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - layer_norm_eps=1e-12, - classifier_dropout_prob=0.1, - position_embedding_type="absolute", - pad_token_id=0, - bos_token_id=2, - eos_token_id=3, - **kwargs, - ): - super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) - - self.vocab_size = vocab_size - self.embedding_size = embedding_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_hidden_groups = num_hidden_groups - self.num_attention_heads = num_attention_heads - self.inner_group_num = inner_group_num - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.classifier_dropout_prob = classifier_dropout_prob - self.position_embedding_type = position_embedding_type + def __init__(self, **kwargs): + super().__init__( + pad_token_id=self.pad_token_id, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, **kwargs + ) + + def __post_init__(self): + self.validate() + + def validate(self): + """Ensures the configuration is valid.""" + if self.hidden_size % self.num_attention_heads != 0 and not hasattr(self, "embedding_size"): + raise ValueError( + f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " + f"heads ({self.num_attention_heads}" + ) # Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 11fd1f939ccd..809e256d4e5c 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -245,12 +245,6 @@ def forward( class AlbertAttention(nn.Module): def __init__(self, config: AlbertConfig): super().__init__() - if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads}" - ) - self.num_attention_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.attention_head_size = config.hidden_size // config.num_attention_heads @@ -898,6 +892,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel): _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"] def __init__(self, config): + config.validate() super().__init__(config) self.albert = AlbertModel(config, add_pooling_layer=False) diff --git a/src/transformers/validators.py b/src/transformers/validators.py new file mode 100644 index 000000000000..224311f98928 --- /dev/null +++ b/src/transformers/validators.py @@ -0,0 +1,76 @@ +# coding=utf-8 +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Validators to be used with `huggingface_hub.utils.strict_dataclass`. We recommend using the validator(s) that best +describe the constraints of your dataclass fields, for a best user experience (e.g. better error messages). +""" + +from typing import Iterable +from .activations import ACT2CLS + + +# Integer validators + + +def positive_int(value: int): + """Ensures that `value` is a positive integer (including 0).""" + if not value >= 0: + raise ValueError(f"Value must be a positive integer, got {value}.") + + +def strictly_positive_int(value: int): + """Ensures that `value` is a positive integer (excluding 0).""" + if not value > 0: + raise ValueError(f"Value must be a strictly positive integer, got {value}.") + + +def vocabulary_token(value: int, vocab_size: int): + """Ensures that `value` is a valid vocabulary token index.""" + if not 0 <= value < vocab_size: + raise ValueError(f"Value must be a token in the vocabulary, got {value}. (vocabulary size = {vocab_size})") + + + +# Float validators + + +def positive_float(value: float): + """Ensures that `value` is a positive float (including 0.0).""" + if not value >= 0: + raise ValueError(f"Value must be a positive float, got {value}.") + + +def probability(value: float): + """Ensures that `value` is a valid probability number, i.e. [0,1].""" + if not 0 <= value <= 1: + raise ValueError(f"Value must be a probability between 0.0 and 1.0, got {value}.") + + +# String validators + + +def activation_function_key(value: str): + """Ensures that `value` is a string corresponding to an activation function.""" + if value not in ACT2CLS: + raise ValueError( + f"Value must be one of {list(ACT2CLS.keys())}, got {value}. " + "Make sure to use a string that corresponds to an activation function." + ) + + +def choice_str(value: str, choices: Iterable[str]): + """Ensures that `value` is one of the choices in `choices`.""" + if value not in choices: + raise ValueError(f"Value must be one of {choices}, got {value}") From 62b3e1242a379ff37416d3589b1d466e4b9717fd Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 4 Mar 2025 13:25:59 +0000 Subject: [PATCH 02/18] validate in PreTrainedModel --- src/transformers/modeling_utils.py | 2 ++ src/transformers/models/albert/modeling_albert.py | 1 - src/transformers/validators.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bd09c1ae57d1..c77faa86a4e9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2025,6 +2025,8 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): "`PretrainedConfig`. To create a model from a pretrained model use " f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" ) + if hasattr(config, "validate"): # e.g. in @strict_dataclass + config.validate() if not getattr(config, "_attn_implementation_autoset", False): # config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype() diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 809e256d4e5c..640bc1a56014 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -892,7 +892,6 @@ class AlbertForMaskedLM(AlbertPreTrainedModel): _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"] def __init__(self, config): - config.validate() super().__init__(config) self.albert = AlbertModel(config, add_pooling_layer=False) diff --git a/src/transformers/validators.py b/src/transformers/validators.py index 224311f98928..c389082b5479 100644 --- a/src/transformers/validators.py +++ b/src/transformers/validators.py @@ -18,6 +18,7 @@ """ from typing import Iterable + from .activations import ACT2CLS @@ -42,7 +43,6 @@ def vocabulary_token(value: int, vocab_size: int): raise ValueError(f"Value must be a token in the vocabulary, got {value}. (vocabulary size = {vocab_size})") - # Float validators From 654845eb6d0fcd177dc05adfa8442d8db0aea2f8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 4 Mar 2025 13:47:31 +0000 Subject: [PATCH 03/18] validation --- src/transformers/models/albert/configuration_albert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index 65d4d932293d..f5eec73be44a 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -151,7 +151,7 @@ def __post_init__(self): def validate(self): """Ensures the configuration is valid.""" - if self.hidden_size % self.num_attention_heads != 0 and not hasattr(self, "embedding_size"): + if self.hidden_size % self.num_attention_heads != 0: raise ValueError( f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " f"heads ({self.num_attention_heads}" From 6318d84c6598327330cba1dd1fd2eed1d5da50ea Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sun, 9 Mar 2025 20:20:41 +0000 Subject: [PATCH 04/18] shorter syntax; NOT FULLY OPERATIONAL --- .../models/albert/configuration_albert.py | 64 +++++++++---------- src/transformers/validators.py | 54 +++++++--------- 2 files changed, 54 insertions(+), 64 deletions(-) diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index f5eec73be44a..ad194864b434 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -16,21 +16,13 @@ """ALBERT model configuration""" from collections import OrderedDict -from typing import Mapping +from typing import Literal, Mapping, Optional from huggingface_hub.utils import strict_dataclass, validated_field from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig -from ...validators import ( - activation_function_key, - choice_str, - positive_float, - positive_int, - probability, - strictly_positive_int, - vocabulary_token, -) +from ...validators import activation_fn_key, interval, probability @strict_dataclass @@ -65,9 +57,9 @@ class AlbertConfig(PretrainedConfig): hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu_new"`): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported. - hidden_dropout_prob (`float`, *optional*, defaults to 0): + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. - attention_probs_dropout_prob (`float`, *optional*, defaults to 0): + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. max_position_embeddings (`int`, *optional*, defaults to 512): The maximum sequence length that this model might ever be used with. Typically set this to something large @@ -115,28 +107,26 @@ class AlbertConfig(PretrainedConfig): >>> configuration = model.config ```""" - vocab_size: int = validated_field(strictly_positive_int, default=30000) - embedding_size: int = validated_field(strictly_positive_int, default=128) - hidden_size: int = validated_field(strictly_positive_int, default=4096) - num_hidden_layers: int = validated_field(strictly_positive_int, default=12) - num_hidden_groups: int = validated_field(strictly_positive_int, default=1) - num_attention_heads: int = validated_field(positive_int, default=64) - intermediate_size: int = validated_field(strictly_positive_int, default=16384) - inner_group_num: int = validated_field(positive_int, default=1) - hidden_act: str = validated_field(activation_function_key, default="gelu_new") - hidden_dropout_prob: float = validated_field(probability, default=0) - attention_probs_dropout_prob: float = validated_field(probability, default=0) - max_position_embeddings: int = validated_field(positive_int, default=512) - type_vocab_size: int = validated_field(strictly_positive_int, default=2) - initializer_range: float = validated_field(positive_float, default=0.02) - layer_norm_eps: float = validated_field(positive_float, default=1e-12) + vocab_size: int = validated_field(interval(min=1), default=30000) + embedding_size: int = validated_field(interval(min=1), default=128) + hidden_size: int = validated_field(interval(min=1), default=4096) + num_hidden_layers: int = validated_field(interval(min=1), default=12) + num_hidden_groups: int = validated_field(interval(min=1), default=1) + num_attention_heads: int = validated_field(interval(min=0), default=64) + intermediate_size: int = validated_field(interval(min=1), default=16384) + inner_group_num: int = validated_field(interval(min=0), default=1) + hidden_act: str = validated_field(activation_fn_key, default="gelu_new") + hidden_dropout_prob: float = validated_field(probability, default=0.0) + attention_probs_dropout_prob: float = validated_field(probability, default=0.0) + max_position_embeddings: int = validated_field(interval(min=0), default=512) + type_vocab_size: int = validated_field(interval(min=1), default=2) + initializer_range: float = validated_field(interval(min=0.0), default=0.02) + layer_norm_eps: float = validated_field(interval(min=0.0), default=1e-12) classifier_dropout_prob: float = validated_field(probability, default=0.1) - position_embedding_type: str = validated_field( - choice_str, choices=["absolute", "relative_key", "relative_key_query"], default="absolute" - ) - pad_token_id: int = validated_field(vocabulary_token, vocab_size=vocab_size, default=0) - bos_token_id: int = validated_field(vocabulary_token, vocab_size=vocab_size, default=2) - eos_token_id: int = validated_field(vocabulary_token, vocab_size=vocab_size, default=3) + position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute" + pad_token_id: Optional[int] = 0 + bos_token_id: Optional[int] = 2 + eos_token_id: Optional[int] = 3 # Not part of __init__ model_type = "albert" @@ -157,6 +147,14 @@ def validate(self): f"heads ({self.num_attention_heads}" ) + # Special tokens must be in the vocabulary + for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]: + token_id = getattr(self, token_name) + if token_id is not None and not 0 <= token_id < self.vocab_size: + raise ValueError( + f"{token_name} must be in the vocabulary with size {self.vocab_size}, got {token_id}." + ) + # Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert class AlbertOnnxConfig(OnnxConfig): diff --git a/src/transformers/validators.py b/src/transformers/validators.py index c389082b5479..ab95ab6b7961 100644 --- a/src/transformers/validators.py +++ b/src/transformers/validators.py @@ -17,39 +17,36 @@ describe the constraints of your dataclass fields, for a best user experience (e.g. better error messages). """ -from typing import Iterable +from typing import Callable, Optional from .activations import ACT2CLS -# Integer validators +# Numerical validators -def positive_int(value: int): - """Ensures that `value` is a positive integer (including 0).""" - if not value >= 0: - raise ValueError(f"Value must be a positive integer, got {value}.") +def interval(min: Optional[int | float] = None, max: Optional[int | float] = None) -> Callable: + """ + Parameterized validator that ensures that `value` is within the defined interval. + Expected usage: `validated_field(interval(min=0), default=8)` + """ + error_message = "Value must be" + if min is not None: + error_message += f" at least {min}" + if min is not None and max is not None: + error_message += " and" + if max is not None: + error_message += f" at most {max}" + error_message += ", got {value}." + min = min or float("-inf") + max = max or float("inf") -def strictly_positive_int(value: int): - """Ensures that `value` is a positive integer (excluding 0).""" - if not value > 0: - raise ValueError(f"Value must be a strictly positive integer, got {value}.") + def _inner(value: int | float): + if not min <= value <= max: + raise ValueError(error_message.format(value=value)) - -def vocabulary_token(value: int, vocab_size: int): - """Ensures that `value` is a valid vocabulary token index.""" - if not 0 <= value < vocab_size: - raise ValueError(f"Value must be a token in the vocabulary, got {value}. (vocabulary size = {vocab_size})") - - -# Float validators - - -def positive_float(value: float): - """Ensures that `value` is a positive float (including 0.0).""" - if not value >= 0: - raise ValueError(f"Value must be a positive float, got {value}.") + return _inner def probability(value: float): @@ -61,16 +58,11 @@ def probability(value: float): # String validators -def activation_function_key(value: str): +def activation_fn_key(value: str): """Ensures that `value` is a string corresponding to an activation function.""" + # TODO (joao): in python 3.11+, we can build a Literal type from the keys of ACT2CLS if value not in ACT2CLS: raise ValueError( f"Value must be one of {list(ACT2CLS.keys())}, got {value}. " "Make sure to use a string that corresponds to an activation function." ) - - -def choice_str(value: str, choices: Iterable[str]): - """Ensures that `value` is one of the choices in `choices`.""" - if value not in choices: - raise ValueError(f"Value must be one of {choices}, got {value}") From 7cc2a9a8b59f043db49e603f0ac003d5b150b2cc Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 10 Mar 2025 17:11:06 +0000 Subject: [PATCH 05/18] almost working --- .../models/albert/configuration_albert.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index ad194864b434..9cd707d8beb5 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -131,28 +131,32 @@ class AlbertConfig(PretrainedConfig): # Not part of __init__ model_type = "albert" - def __init__(self, **kwargs): + def __post_init__(self): + """Called after `__init__` from the dataclass: initializes parent classes and validates the instance.""" super().__init__( - pad_token_id=self.pad_token_id, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, **kwargs + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + # **kwargs -> this is missing ) - - def __post_init__(self): self.validate() def validate(self): - """Ensures the configuration is valid.""" + """Ensures the configuration is valid by assessing combinations of arguments.""" + # Architecture validation if self.hidden_size % self.num_attention_heads != 0: raise ValueError( f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " f"heads ({self.num_attention_heads}" ) - # Special tokens must be in the vocabulary + # Token validation for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]: token_id = getattr(self, token_name) if token_id is not None and not 0 <= token_id < self.vocab_size: raise ValueError( - f"{token_name} must be in the vocabulary with size {self.vocab_size}, got {token_id}." + f"{token_name} must be in the vocabulary with size {self.vocab_size}, i.e. between 0 and " + f"{self.vocab_size - 1}, got {token_id}." ) From 0c7c4c5dbb9431fe56df1cf8a948739c96d7e3e8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 2 May 2025 11:22:48 +0000 Subject: [PATCH 06/18] update with latest hub changes --- src/transformers/modeling_utils.py | 4 +- .../models/albert/configuration_albert.py | 22 ++- src/transformers/validators.py | 48 +++++-- tests/utils/test_validators.py | 135 ++++++++++++++++++ 4 files changed, 187 insertions(+), 22 deletions(-) create mode 100644 tests/utils/test_validators.py diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c77faa86a4e9..42aa36210d73 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2025,8 +2025,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): "`PretrainedConfig`. To create a model from a pretrained model use " f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" ) - if hasattr(config, "validate"): # e.g. in @strict_dataclass + # class-level validation of config (as opposed to the attribute-level validation provided by `@strict`) + if hasattr(config, "validate"): config.validate() + if not getattr(config, "_attn_implementation_autoset", False): # config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype() diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index 9cd707d8beb5..789f2ccc250a 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -16,16 +16,18 @@ """ALBERT model configuration""" from collections import OrderedDict +from dataclasses import dataclass from typing import Literal, Mapping, Optional -from huggingface_hub.utils import strict_dataclass, validated_field +from huggingface_hub.dataclasses import strict, validated_field from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig -from ...validators import activation_fn_key, interval, probability +from ...validators import activation_fn_key, interval, probability, token -@strict_dataclass +@strict(accept_kwargs=True) +@dataclass class AlbertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`AlbertModel`] or a [`TFAlbertModel`]. It is used @@ -124,21 +126,15 @@ class AlbertConfig(PretrainedConfig): layer_norm_eps: float = validated_field(interval(min=0.0), default=1e-12) classifier_dropout_prob: float = validated_field(probability, default=0.1) position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute" - pad_token_id: Optional[int] = 0 - bos_token_id: Optional[int] = 2 - eos_token_id: Optional[int] = 3 + pad_token_id: Optional[int] = validated_field(token, default=0) + bos_token_id: Optional[int] = validated_field(token, default=2) + eos_token_id: Optional[int] = validated_field(token, default=3) # Not part of __init__ model_type = "albert" def __post_init__(self): - """Called after `__init__` from the dataclass: initializes parent classes and validates the instance.""" - super().__init__( - pad_token_id=self.pad_token_id, - bos_token_id=self.bos_token_id, - eos_token_id=self.eos_token_id, - # **kwargs -> this is missing - ) + """Called after `__init__`: validates the instance.""" self.validate() def validate(self): diff --git a/src/transformers/validators.py b/src/transformers/validators.py index ab95ab6b7961..71083a8b0ec4 100644 --- a/src/transformers/validators.py +++ b/src/transformers/validators.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Validators to be used with `huggingface_hub.utils.strict_dataclass`. We recommend using the validator(s) that best -describe the constraints of your dataclass fields, for a best user experience (e.g. better error messages). +Validators to be used with `huggingface_hub.dataclasses.validated_field`. We recommend using the validator(s) that best +describe the constraints of your dataclass fields, for the best user experience (e.g. better error messages). """ from typing import Callable, Optional @@ -25,25 +25,48 @@ # Numerical validators -def interval(min: Optional[int | float] = None, max: Optional[int | float] = None) -> Callable: +def interval( + min: Optional[int | float] = None, + max: Optional[int | float] = None, + exclude_min: bool = False, + exclude_max: bool = False, +) -> Callable: """ - Parameterized validator that ensures that `value` is within the defined interval. - Expected usage: `validated_field(interval(min=0), default=8)` + Parameterized validator that ensures that `value` is within the defined interval. Optionally, the interval can be + open on either side. Expected usage: `validated_field(interval(min=0), default=8)` + + Args: + min (`int` or `float`, *optional*): + Minimum value of the interval. + max (`int` or `float`, *optional*): + Maximum value of the interval. + exclude_min (`bool`, *optional*, defaults to `False`): + If True, the minimum value is excluded from the interval. + exclude_max (`bool`, *optional*, defaults to `False`): + If True, the maximum value is excluded from the interval. """ error_message = "Value must be" if min is not None: - error_message += f" at least {min}" + if exclude_min: + error_message += f" greater than {min}" + else: + error_message += f" greater or equal to {min}" if min is not None and max is not None: error_message += " and" if max is not None: - error_message += f" at most {max}" + if exclude_max: + error_message += f" smaller than {max}" + else: + error_message += f" smaller or equal to {max}" error_message += ", got {value}." min = min or float("-inf") max = max or float("inf") def _inner(value: int | float): - if not min <= value <= max: + min_valid = min <= value if not exclude_min else min < value + max_valid = value <= max if not exclude_max else value < max + if not (min_valid and max_valid): raise ValueError(error_message.format(value=value)) return _inner @@ -55,6 +78,12 @@ def probability(value: float): raise ValueError(f"Value must be a probability between 0.0 and 1.0, got {value}.") +def token(value: Optional[int]): + """Ensures that `value` is a potential token. A token, when set, must be a non-negative integer.""" + if value is not None and value < 0: + raise ValueError(f"A token, when set, must be a non-negative integer, got {value}.") + + # String validators @@ -66,3 +95,6 @@ def activation_fn_key(value: str): f"Value must be one of {list(ACT2CLS.keys())}, got {value}. " "Make sure to use a string that corresponds to an activation function." ) + + +__all__ = ["interval", "probability", "token", "activation_fn_key"] diff --git a/tests/utils/test_validators.py b/tests/utils/test_validators.py new file mode 100644 index 000000000000..b309857797be --- /dev/null +++ b/tests/utils/test_validators.py @@ -0,0 +1,135 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from transformers import AlbertConfig +from transformers.validators import interval, probability, token, activation_fn_key +from huggingface_hub.dataclasses import StrictDataclassFieldValidationError + + +class ValidatorsTests(unittest.TestCase): + """ + Sanity check tests for the validators. Note that the validators do not perform strict type checking + (`huggingface_hub.dataclasses.strict` is used for that). + """ + + def test_interval(self): + # valid + interval(1, 10)(5) + interval(1, 10)(5.0) + interval(1, 10)(10) + interval(1, 10)(1) + interval(1, 10, exclude_min=True)(1.0000001) + + # invalid + with self.assertRaises(ValueError): + interval(1, 10)(11) # greater than max + with self.assertRaises(ValueError): + interval(1, 10)(0.9999999) # less than min + with self.assertRaises(ValueError): + interval(1, 10, exclude_max=True)(10) # equal to max, but exclude_max is True + with self.assertRaises(ValueError): + interval(1, 10, exclude_min=True)(1.0) # equal to min, but exclude_min is True + with self.assertRaises(ValueError): + interval(1, 10)(-5) # less than min + + def test_probability(self): + # valid + probability(0.5) + probability(0) + probability(1) + + # invalid + with self.assertRaises(ValueError): + probability(99) # 0-1 probabilities only + with self.assertRaises(ValueError): + probability(1.1) # greater than 1 + with self.assertRaises(ValueError): + probability(-0.1) # less than 0 + + def test_token(self): + # valid + token(None) + token(0) + token(1) + token(999999999) + + # invalid + with self.assertRaises(ValueError): + token(-1) # less than 0 + with self.assertRaises(TypeError): + token("") # must be the token id, not its string counterpart + + def test_activation_fn_key(self): + # valid + activation_fn_key("relu") + activation_fn_key("gelu") + + # invalid + with self.assertRaises(ValueError): + activation_fn_key("foo") # obvious one + with self.assertRaises(ValueError): + activation_fn_key(None) # can't be None + with self.assertRaises(ValueError): + activation_fn_key("Relu") # typo: should be "relu", not "Relu" + + +class ValidatorsIntegrationTests(unittest.TestCase): + """Tests in which the validators are used as part of another class/function""" + + def test_model_config_validation(self): + """Sanity check tests for the integration of model config with `huggingface_hub.dataclasses.strict`""" + # 1 - We can initialize the config, including with arbitrary kwargs + config = AlbertConfig() + config = AlbertConfig(eos_token_id=5) + self.assertEqual(config.eos_token_id, 5) + config = AlbertConfig(eos_token_id=None) + self.assertIsNone(config.eos_token_id) + config = AlbertConfig(foo="bar") # Ensures backwards compatibility + self.assertEqual(config.foo, "bar") + + # 2 - Manual specification, traveling through an invalid config, should be allowed + config.eos_token_id = 99 # vocab_size = 30000, eos_token_id = 99 -> valid + config.vocab_size = 10 # vocab_size = 10, eos_token_id = 99 -> invalid (but only throws error in `validate()`) + with self.assertRaises(ValueError): + config.validate() + config.eos_token_id = 9 # vocab_size = 10, eos_token_id = 9 -> valid + config.validate() + + # 3 - These cases should raise an error + + # vocab_size is an int + with self.assertRaises(StrictDataclassFieldValidationError): + config = AlbertConfig(vocab_size=10.0) + + # num_hidden_layers is an int + with self.assertRaises(StrictDataclassFieldValidationError): + config = AlbertConfig(num_hidden_layers=None) + + # position_embedding_type is a Literal, foo is not one of the options + with self.assertRaises(StrictDataclassFieldValidationError): + config = AlbertConfig(position_embedding_type="foo") + + # eos_token_id is a token, and must be non-negative + with self.assertRaises(StrictDataclassFieldValidationError): + config = AlbertConfig(eos_token_id=-1) + + # `validate()` is called in `__post_init__`, i.e. after `__init__`. A special token must be in the vocabulary. + with self.assertRaises(ValueError): + config = AlbertConfig(vocab_size=10, eos_token_id=99) + + # vocab size is assigned after init, individual attributes are checked on assignment + with self.assertRaises(StrictDataclassFieldValidationError): + config = AlbertConfig() + config.vocab_size = "foo" From 4b14696947efa46bd78c91bb148ff9b0d555bb46 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 May 2025 14:10:49 +0000 Subject: [PATCH 07/18] update with latest validator changes (shorter syntax) --- src/transformers/configuration_utils.py | 16 +++ .../models/albert/configuration_albert.py | 52 ++++---- src/transformers/validators.py | 6 + tests/utils/test_validators.py | 126 +++++++++++------- 4 files changed, 126 insertions(+), 74 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 205a7dde8f28..5ad5c822d86e 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -382,6 +382,22 @@ def _attn_implementation(self): def _attn_implementation(self, value): self._attn_implementation_internal = value + def validate(self): + """ + Validates the contents of the config. + """ + # Special token validation + text_config = self.get_text_config() + vocab_size = getattr(text_config, "vocab_size", None) + if vocab_size is not None: + for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]: + token_id = getattr(text_config, token_name, None) + if token_id is not None and not 0 <= token_id < vocab_size: + raise ValueError( + f"{token_name} must be `None` or an integer within the vocabulary, " + f"i.e. between 0 and {vocab_size - 1}, got {token_id}." + ) + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index 789f2ccc250a..4f612ee197dd 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -19,7 +19,7 @@ from dataclasses import dataclass from typing import Literal, Mapping, Optional -from huggingface_hub.dataclasses import strict, validated_field +from huggingface_hub.dataclasses import strict from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig @@ -109,26 +109,26 @@ class AlbertConfig(PretrainedConfig): >>> configuration = model.config ```""" - vocab_size: int = validated_field(interval(min=1), default=30000) - embedding_size: int = validated_field(interval(min=1), default=128) - hidden_size: int = validated_field(interval(min=1), default=4096) - num_hidden_layers: int = validated_field(interval(min=1), default=12) - num_hidden_groups: int = validated_field(interval(min=1), default=1) - num_attention_heads: int = validated_field(interval(min=0), default=64) - intermediate_size: int = validated_field(interval(min=1), default=16384) - inner_group_num: int = validated_field(interval(min=0), default=1) - hidden_act: str = validated_field(activation_fn_key, default="gelu_new") - hidden_dropout_prob: float = validated_field(probability, default=0.0) - attention_probs_dropout_prob: float = validated_field(probability, default=0.0) - max_position_embeddings: int = validated_field(interval(min=0), default=512) - type_vocab_size: int = validated_field(interval(min=1), default=2) - initializer_range: float = validated_field(interval(min=0.0), default=0.02) - layer_norm_eps: float = validated_field(interval(min=0.0), default=1e-12) - classifier_dropout_prob: float = validated_field(probability, default=0.1) + vocab_size: int = interval(min=1)(default=30000) + embedding_size: int = interval(min=1)(default=128) + hidden_size: int = interval(min=1)(default=4096) + num_hidden_layers: int = interval(min=1)(default=12) + num_hidden_groups: int = interval(min=1)(default=1) + num_attention_heads: int = interval(min=0)(default=64) + intermediate_size: int = interval(min=1)(default=16384) + inner_group_num: int = interval(min=0)(default=1) + hidden_act: str = activation_fn_key(default="gelu_new") + hidden_dropout_prob: float = probability(default=0.0) + attention_probs_dropout_prob: float = probability(default=0.0) + max_position_embeddings: int = interval(min=0)(default=512) + type_vocab_size: int = interval(min=1)(default=2) + initializer_range: float = interval(min=0.0)(default=0.02) + layer_norm_eps: float = interval(min=0.0)(default=1e-12) + classifier_dropout_prob: float = probability(default=0.1) position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute" - pad_token_id: Optional[int] = validated_field(token, default=0) - bos_token_id: Optional[int] = validated_field(token, default=2) - eos_token_id: Optional[int] = validated_field(token, default=3) + pad_token_id: Optional[int] = token(default=0) + bos_token_id: Optional[int] = token(default=2) + eos_token_id: Optional[int] = token(default=3) # Not part of __init__ model_type = "albert" @@ -139,6 +139,9 @@ def __post_init__(self): def validate(self): """Ensures the configuration is valid by assessing combinations of arguments.""" + # Generic config validation + super().validate() + # Architecture validation if self.hidden_size % self.num_attention_heads != 0: raise ValueError( @@ -146,15 +149,6 @@ def validate(self): f"heads ({self.num_attention_heads}" ) - # Token validation - for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]: - token_id = getattr(self, token_name) - if token_id is not None and not 0 <= token_id < self.vocab_size: - raise ValueError( - f"{token_name} must be in the vocabulary with size {self.vocab_size}, i.e. between 0 and " - f"{self.vocab_size - 1}, got {token_id}." - ) - # Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert class AlbertOnnxConfig(OnnxConfig): diff --git a/src/transformers/validators.py b/src/transformers/validators.py index 71083a8b0ec4..f74f66a11896 100644 --- a/src/transformers/validators.py +++ b/src/transformers/validators.py @@ -19,6 +19,8 @@ from typing import Callable, Optional +from huggingface_hub.dataclasses import as_validated_field + from .activations import ACT2CLS @@ -63,6 +65,7 @@ def interval( min = min or float("-inf") max = max or float("inf") + @as_validated_field def _inner(value: int | float): min_valid = min <= value if not exclude_min else min < value max_valid = value <= max if not exclude_max else value < max @@ -72,12 +75,14 @@ def _inner(value: int | float): return _inner +@as_validated_field def probability(value: float): """Ensures that `value` is a valid probability number, i.e. [0,1].""" if not 0 <= value <= 1: raise ValueError(f"Value must be a probability between 0.0 and 1.0, got {value}.") +@as_validated_field def token(value: Optional[int]): """Ensures that `value` is a potential token. A token, when set, must be a non-negative integer.""" if value is not None and value < 0: @@ -87,6 +92,7 @@ def token(value: Optional[int]): # String validators +@as_validated_field def activation_fn_key(value: str): """Ensures that `value` is a string corresponding to an activation function.""" # TODO (joao): in python 3.11+, we can build a Literal type from the keys of ACT2CLS diff --git a/tests/utils/test_validators.py b/tests/utils/test_validators.py index b309857797be..10d6f3735280 100644 --- a/tests/utils/test_validators.py +++ b/tests/utils/test_validators.py @@ -12,77 +12,113 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest +from dataclasses import dataclass +from typing import Optional, Union + +from huggingface_hub.dataclasses import StrictDataclassFieldValidationError, strict from transformers import AlbertConfig -from transformers.validators import interval, probability, token, activation_fn_key -from huggingface_hub.dataclasses import StrictDataclassFieldValidationError +from transformers.validators import activation_fn_key, interval, probability, token class ValidatorsTests(unittest.TestCase): """ - Sanity check tests for the validators. Note that the validators do not perform strict type checking - (`huggingface_hub.dataclasses.strict` is used for that). + Sanity check tests for the validators. Validators are `field` in a dataclass, and not meant to be used on + their own. """ def test_interval(self): + # Setup test dataclasses + @strict + @dataclass + class TestInterval: + data: Union[int, float] = interval(min=1, max=10)() + + @strict + @dataclass + class TestIntervalExcludeMinMax: + data: Union[int, float] = interval(min=1, max=10, exclude_min=True, exclude_max=True)() + # valid - interval(1, 10)(5) - interval(1, 10)(5.0) - interval(1, 10)(10) - interval(1, 10)(1) - interval(1, 10, exclude_min=True)(1.0000001) + TestInterval(5) + TestInterval(5.0) + TestInterval(10) + TestInterval(1) + TestIntervalExcludeMinMax(1.0000001) # invalid - with self.assertRaises(ValueError): - interval(1, 10)(11) # greater than max - with self.assertRaises(ValueError): - interval(1, 10)(0.9999999) # less than min - with self.assertRaises(ValueError): - interval(1, 10, exclude_max=True)(10) # equal to max, but exclude_max is True - with self.assertRaises(ValueError): - interval(1, 10, exclude_min=True)(1.0) # equal to min, but exclude_min is True - with self.assertRaises(ValueError): - interval(1, 10)(-5) # less than min + with self.assertRaises(StrictDataclassFieldValidationError): + TestInterval("one") # different type + with self.assertRaises(StrictDataclassFieldValidationError): + TestInterval(11) # greater than max + with self.assertRaises(StrictDataclassFieldValidationError): + TestInterval(0.9999999) # less than min + with self.assertRaises(StrictDataclassFieldValidationError): + TestIntervalExcludeMinMax(10) # equal to max, but exclude_max is True + with self.assertRaises(StrictDataclassFieldValidationError): + TestIntervalExcludeMinMax(1.0) # equal to min, but exclude_min is True + with self.assertRaises(StrictDataclassFieldValidationError): + TestInterval(-5) # less than min def test_probability(self): + # Setup test dataclasses + @strict + @dataclass + class TestProbability: + data: float = probability() + # valid - probability(0.5) - probability(0) - probability(1) + TestProbability(0.5) + TestProbability(0.0) + TestProbability(1.0) # invalid - with self.assertRaises(ValueError): - probability(99) # 0-1 probabilities only - with self.assertRaises(ValueError): - probability(1.1) # greater than 1 - with self.assertRaises(ValueError): - probability(-0.1) # less than 0 + with self.assertRaises(StrictDataclassFieldValidationError): + TestProbability(1) # different type + with self.assertRaises(StrictDataclassFieldValidationError): + TestProbability(99.0) # 0-1 probabilities only + with self.assertRaises(StrictDataclassFieldValidationError): + TestProbability(1.1) # greater than 1 + with self.assertRaises(StrictDataclassFieldValidationError): + TestProbability(-0.1) # less than 0 def test_token(self): + # Setup test dataclasses + @strict + @dataclass + class TestToken: + data: Optional[int] = token() + # valid - token(None) - token(0) - token(1) - token(999999999) + TestToken(None) + TestToken(0) + TestToken(1) + TestToken(999999999) # invalid - with self.assertRaises(ValueError): - token(-1) # less than 0 - with self.assertRaises(TypeError): - token("") # must be the token id, not its string counterpart + with self.assertRaises(StrictDataclassFieldValidationError): + TestToken(-1) # less than 0 + with self.assertRaises(StrictDataclassFieldValidationError): + TestToken("") # different type: must be the token id, not its string counterpart def test_activation_fn_key(self): + # Setup test dataclasses + @strict + @dataclass + class TestActivationFnKey: + data: str = activation_fn_key() + # valid - activation_fn_key("relu") - activation_fn_key("gelu") + TestActivationFnKey("relu") + TestActivationFnKey("gelu") # invalid - with self.assertRaises(ValueError): - activation_fn_key("foo") # obvious one - with self.assertRaises(ValueError): - activation_fn_key(None) # can't be None - with self.assertRaises(ValueError): - activation_fn_key("Relu") # typo: should be "relu", not "Relu" + with self.assertRaises(StrictDataclassFieldValidationError): + TestActivationFnKey("foo") # obvious one + with self.assertRaises(StrictDataclassFieldValidationError): + TestActivationFnKey(None) # different type: can't be None + with self.assertRaises(StrictDataclassFieldValidationError): + TestActivationFnKey("Relu") # typo: should be "relu", not "Relu" class ValidatorsIntegrationTests(unittest.TestCase): @@ -121,7 +157,7 @@ def test_model_config_validation(self): with self.assertRaises(StrictDataclassFieldValidationError): config = AlbertConfig(position_embedding_type="foo") - # eos_token_id is a token, and must be non-negative + # eos_token_id is a token, and must be non-negative with self.assertRaises(StrictDataclassFieldValidationError): config = AlbertConfig(eos_token_id=-1) From e73505eba094556f680c2021dbc6b22d1387807b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 May 2025 14:43:44 +0000 Subject: [PATCH 08/18] validate functions --- src/transformers/configuration_utils.py | 7 ++----- .../models/albert/configuration_albert.py | 12 ++---------- tests/utils/test_validators.py | 9 ++++++++- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 5ad5c822d86e..c382e87378ec 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -382,11 +382,8 @@ def _attn_implementation(self): def _attn_implementation(self, value): self._attn_implementation_internal = value - def validate(self): - """ - Validates the contents of the config. - """ - # Special token validation + def validate_token_ids(self): + """Part of `@strict`-powered validation. Validates the contents of the special tokens.""" text_config = self.get_text_config() vocab_size = getattr(text_config, "vocab_size", None) if vocab_size is not None: diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index 4f612ee197dd..eaf33353b0be 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -133,16 +133,8 @@ class AlbertConfig(PretrainedConfig): # Not part of __init__ model_type = "albert" - def __post_init__(self): - """Called after `__init__`: validates the instance.""" - self.validate() - - def validate(self): - """Ensures the configuration is valid by assessing combinations of arguments.""" - # Generic config validation - super().validate() - - # Architecture validation + def validate_architecture(self): + """Part of `@strict`-powered validation. Validates the architecture of the config.""" if self.hidden_size % self.num_attention_heads != 0: raise ValueError( f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " diff --git a/tests/utils/test_validators.py b/tests/utils/test_validators.py index 10d6f3735280..8710ad4a523c 100644 --- a/tests/utils/test_validators.py +++ b/tests/utils/test_validators.py @@ -161,10 +161,17 @@ def test_model_config_validation(self): with self.assertRaises(StrictDataclassFieldValidationError): config = AlbertConfig(eos_token_id=-1) - # `validate()` is called in `__post_init__`, i.e. after `__init__`. A special token must be in the vocabulary. + # `@strict` calls `validate()` in `__post_init__`, i.e. after `__init__`. All functions defined as + # `validate_XXX(self)` will be called as part of the validation process. In this case, a special token must + # be in the vocabulary, and the validation function is defined in the base config class. with self.assertRaises(ValueError): config = AlbertConfig(vocab_size=10, eos_token_id=99) + # Similar to the previous case, but the validation function is defined in the model config class. The hidden + # size must be divisible by the number of attention heads. + with self.assertRaises(ValueError): + config = AlbertConfig(hidden_size=10, num_attention_heads=3) + # vocab size is assigned after init, individual attributes are checked on assignment with self.assertRaises(StrictDataclassFieldValidationError): config = AlbertConfig() From 9016a8030d0e79a8e21885944f2e651444543de4 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 May 2025 15:25:59 +0000 Subject: [PATCH 09/18] update thrown exceptions in test --- tests/utils/test_validators.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_validators.py b/tests/utils/test_validators.py index 8710ad4a523c..16f5ce138bed 100644 --- a/tests/utils/test_validators.py +++ b/tests/utils/test_validators.py @@ -15,7 +15,11 @@ from dataclasses import dataclass from typing import Optional, Union -from huggingface_hub.dataclasses import StrictDataclassFieldValidationError, strict +from huggingface_hub.dataclasses import ( + StrictDataclassClassValidationError, + StrictDataclassFieldValidationError, + strict, +) from transformers import AlbertConfig from transformers.validators import activation_fn_key, interval, probability, token @@ -138,7 +142,7 @@ def test_model_config_validation(self): # 2 - Manual specification, traveling through an invalid config, should be allowed config.eos_token_id = 99 # vocab_size = 30000, eos_token_id = 99 -> valid config.vocab_size = 10 # vocab_size = 10, eos_token_id = 99 -> invalid (but only throws error in `validate()`) - with self.assertRaises(ValueError): + with self.assertRaises(StrictDataclassClassValidationError): config.validate() config.eos_token_id = 9 # vocab_size = 10, eos_token_id = 9 -> valid config.validate() @@ -164,12 +168,12 @@ def test_model_config_validation(self): # `@strict` calls `validate()` in `__post_init__`, i.e. after `__init__`. All functions defined as # `validate_XXX(self)` will be called as part of the validation process. In this case, a special token must # be in the vocabulary, and the validation function is defined in the base config class. - with self.assertRaises(ValueError): + with self.assertRaises(StrictDataclassClassValidationError): config = AlbertConfig(vocab_size=10, eos_token_id=99) # Similar to the previous case, but the validation function is defined in the model config class. The hidden # size must be divisible by the number of attention heads. - with self.assertRaises(ValueError): + with self.assertRaises(StrictDataclassClassValidationError): config = AlbertConfig(hidden_size=10, num_attention_heads=3) # vocab size is assigned after init, individual attributes are checked on assignment From e1c6e898a0cd3bc836cccdf409ffba29c0d3eb1c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 15 May 2025 16:29:21 +0100 Subject: [PATCH 10/18] Update src/transformers/validators.py Co-authored-by: Lucain --- src/transformers/validators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/validators.py b/src/transformers/validators.py index f74f66a11896..1cd835e0da11 100644 --- a/src/transformers/validators.py +++ b/src/transformers/validators.py @@ -35,7 +35,7 @@ def interval( ) -> Callable: """ Parameterized validator that ensures that `value` is within the defined interval. Optionally, the interval can be - open on either side. Expected usage: `validated_field(interval(min=0), default=8)` + open on either side. Expected usage: `interval(min=0)(default=8)` Args: min (`int` or `float`, *optional*): From bf3944b4f471746c88f09e73679c49eab2d19af1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 29 May 2025 16:20:51 +0000 Subject: [PATCH 11/18] lower severity of token checks (exception -> warning) --- setup.py | 2 +- src/transformers/configuration_utils.py | 6 +++--- src/transformers/dependency_versions_table.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 2b74308081ef..cb3e90adcd8d 100644 --- a/setup.py +++ b/setup.py @@ -117,7 +117,7 @@ "GitPython<3.1.19", "hf-doc-builder>=0.3.0", "hf_xet", - "huggingface-hub>=0.30.0,<1.0", + "huggingface-hub>=0.31.4,<1.0", "importlib_metadata", "ipadic>=1.0.0,<2.0", "isort>=5.5.4", diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index c382e87378ec..6e42c42d50ff 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -390,9 +390,9 @@ def validate_token_ids(self): for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]: token_id = getattr(text_config, token_name, None) if token_id is not None and not 0 <= token_id < vocab_size: - raise ValueError( - f"{token_name} must be `None` or an integer within the vocabulary, " - f"i.e. between 0 and {vocab_size - 1}, got {token_id}." + logger.warning_once( + f"{token_name} must be `None` or an integer within the vocabulary (between 0 and " + f"{vocab_size - 1}), got {token_id}. This may lead to unexpected behavior." ) def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 5c0ae6b772f3..625fd04eee47 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -24,7 +24,7 @@ "GitPython": "GitPython<3.1.19", "hf-doc-builder": "hf-doc-builder>=0.3.0", "hf_xet": "hf_xet", - "huggingface-hub": "huggingface-hub>=0.30.0,<1.0", + "huggingface-hub": "huggingface-hub>=0.31.4,<1.0", "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", "isort": "isort>=5.5.4", From 47478c80d195a33224fe70d55714497cdd59bd07 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 29 May 2025 18:06:23 +0000 Subject: [PATCH 12/18] import guards; reduce token checks; validate on save --- src/transformers/configuration_utils.py | 10 +++- .../models/albert/configuration_albert.py | 16 +++--- src/transformers/validators.py | 30 +++++------ tests/utils/test_validators.py | 54 +++++++++---------- 4 files changed, 55 insertions(+), 55 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 6e42c42d50ff..03f55f2a99e6 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -390,9 +390,11 @@ def validate_token_ids(self): for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]: token_id = getattr(text_config, token_name, None) if token_id is not None and not 0 <= token_id < vocab_size: + # Can't be an exception until we can load configs that fail validation: several configs on the Hub + # store invalid special tokens, e.g. `pad_token_id=-1` logger.warning_once( - f"{token_name} must be `None` or an integer within the vocabulary (between 0 and " - f"{vocab_size - 1}), got {token_id}. This may lead to unexpected behavior." + f"Model config: {token_name} must be `None` or an integer within the vocabulary (between 0 " + f"and {vocab_size - 1}), got {token_id}. This may result in unexpected behavior." ) def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): @@ -427,6 +429,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: UserWarning, ) + # Strict validation at save-time: prevent bad patterns from propagating + if hasattr(self, "validate"): + self.validate() + os.makedirs(save_directory, exist_ok=True) if push_to_hub: diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index eaf33353b0be..6486484d3c60 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -17,13 +17,13 @@ from collections import OrderedDict from dataclasses import dataclass -from typing import Literal, Mapping, Optional +from typing import Literal, Mapping, Optional, Union from huggingface_hub.dataclasses import strict from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig -from ...validators import activation_fn_key, interval, probability, token +from ...validators import activation_fn_key, interval, probability @strict(accept_kwargs=True) @@ -118,17 +118,17 @@ class AlbertConfig(PretrainedConfig): intermediate_size: int = interval(min=1)(default=16384) inner_group_num: int = interval(min=0)(default=1) hidden_act: str = activation_fn_key(default="gelu_new") - hidden_dropout_prob: float = probability(default=0.0) - attention_probs_dropout_prob: float = probability(default=0.0) + hidden_dropout_prob: Union[float, int] = probability(default=0.0) + attention_probs_dropout_prob: Union[float, int] = probability(default=0.0) max_position_embeddings: int = interval(min=0)(default=512) type_vocab_size: int = interval(min=1)(default=2) initializer_range: float = interval(min=0.0)(default=0.02) layer_norm_eps: float = interval(min=0.0)(default=1e-12) classifier_dropout_prob: float = probability(default=0.1) position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute" - pad_token_id: Optional[int] = token(default=0) - bos_token_id: Optional[int] = token(default=2) - eos_token_id: Optional[int] = token(default=3) + pad_token_id: Optional[int] = 0 + bos_token_id: Optional[int] = 2 + eos_token_id: Optional[int] = 3 # Not part of __init__ model_type = "albert" @@ -138,7 +138,7 @@ def validate_architecture(self): if self.hidden_size % self.num_attention_heads != 0: raise ValueError( f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " - f"heads ({self.num_attention_heads}" + f"heads ({self.num_attention_heads})." ) diff --git a/src/transformers/validators.py b/src/transformers/validators.py index 1cd835e0da11..a6cf132577a3 100644 --- a/src/transformers/validators.py +++ b/src/transformers/validators.py @@ -21,7 +21,13 @@ from huggingface_hub.dataclasses import as_validated_field -from .activations import ACT2CLS +from .utils import is_torch_available + + +if is_torch_available(): + from .activations import ACT2FN +else: + ACT2FN = {} # Numerical validators @@ -82,25 +88,19 @@ def probability(value: float): raise ValueError(f"Value must be a probability between 0.0 and 1.0, got {value}.") -@as_validated_field -def token(value: Optional[int]): - """Ensures that `value` is a potential token. A token, when set, must be a non-negative integer.""" - if value is not None and value < 0: - raise ValueError(f"A token, when set, must be a non-negative integer, got {value}.") - - # String validators @as_validated_field def activation_fn_key(value: str): """Ensures that `value` is a string corresponding to an activation function.""" - # TODO (joao): in python 3.11+, we can build a Literal type from the keys of ACT2CLS - if value not in ACT2CLS: - raise ValueError( - f"Value must be one of {list(ACT2CLS.keys())}, got {value}. " - "Make sure to use a string that corresponds to an activation function." - ) + # TODO (joao): in python 3.11+, we can build a Literal type from the keys of ACT2FN + if len(ACT2FN) > 0: # don't validate if we can't import ACT2FN + if value not in ACT2FN: + raise ValueError( + f"Value must be one of {list(ACT2FN.keys())}, got {value}. " + "Make sure to use a string that corresponds to an activation function." + ) -__all__ = ["interval", "probability", "token", "activation_fn_key"] +__all__ = ["interval", "probability", "activation_fn_key"] diff --git a/tests/utils/test_validators.py b/tests/utils/test_validators.py index 16f5ce138bed..58f352e2228d 100644 --- a/tests/utils/test_validators.py +++ b/tests/utils/test_validators.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import tempfile import unittest from dataclasses import dataclass -from typing import Optional, Union +from typing import Union from huggingface_hub.dataclasses import ( StrictDataclassClassValidationError, @@ -21,8 +22,9 @@ strict, ) -from transformers import AlbertConfig -from transformers.validators import activation_fn_key, interval, probability, token +from transformers import AlbertConfig, logging +from transformers.testing_utils import CaptureLogger +from transformers.validators import activation_fn_key, interval, probability class ValidatorsTests(unittest.TestCase): @@ -86,25 +88,6 @@ class TestProbability: with self.assertRaises(StrictDataclassFieldValidationError): TestProbability(-0.1) # less than 0 - def test_token(self): - # Setup test dataclasses - @strict - @dataclass - class TestToken: - data: Optional[int] = token() - - # valid - TestToken(None) - TestToken(0) - TestToken(1) - TestToken(999999999) - - # invalid - with self.assertRaises(StrictDataclassFieldValidationError): - TestToken(-1) # less than 0 - with self.assertRaises(StrictDataclassFieldValidationError): - TestToken("") # different type: must be the token id, not its string counterpart - def test_activation_fn_key(self): # Setup test dataclasses @strict @@ -140,11 +123,10 @@ def test_model_config_validation(self): self.assertEqual(config.foo, "bar") # 2 - Manual specification, traveling through an invalid config, should be allowed - config.eos_token_id = 99 # vocab_size = 30000, eos_token_id = 99 -> valid - config.vocab_size = 10 # vocab_size = 10, eos_token_id = 99 -> invalid (but only throws error in `validate()`) + config.hidden_size = 65 # breaks class-wide validation, see `AlbertConfig.validate_architecture` with self.assertRaises(StrictDataclassClassValidationError): config.validate() - config.eos_token_id = 9 # vocab_size = 10, eos_token_id = 9 -> valid + config.num_attention_heads = 5 # 65 % 5 = 0 -> valid config.validate() # 3 - These cases should raise an error @@ -161,15 +143,13 @@ def test_model_config_validation(self): with self.assertRaises(StrictDataclassFieldValidationError): config = AlbertConfig(position_embedding_type="foo") - # eos_token_id is a token, and must be non-negative - with self.assertRaises(StrictDataclassFieldValidationError): - config = AlbertConfig(eos_token_id=-1) - # `@strict` calls `validate()` in `__post_init__`, i.e. after `__init__`. All functions defined as # `validate_XXX(self)` will be called as part of the validation process. In this case, a special token must # be in the vocabulary, and the validation function is defined in the base config class. - with self.assertRaises(StrictDataclassClassValidationError): + logger = logging.get_logger("transformers.configuration_utils") + with CaptureLogger(logger) as captured_logs: config = AlbertConfig(vocab_size=10, eos_token_id=99) + self.assertIn("eos_token_id must be `None` or an integer within the vocabulary", captured_logs.out) # Similar to the previous case, but the validation function is defined in the model config class. The hidden # size must be divisible by the number of attention heads. @@ -180,3 +160,17 @@ def test_model_config_validation(self): with self.assertRaises(StrictDataclassFieldValidationError): config = AlbertConfig() config.vocab_size = "foo" + + def test_bad_config_cant_be_saved(self): + """Test that a bad config can't be saved""" + # 1 - create a good config, modify it so it fails class-wide validation + config = AlbertConfig() + config.validate() + config.hidden_size = 65 # breaks class-wide validation, see `AlbertConfig.validate_architecture` + + # 2 - try to save it, and check that the error message is correct + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertRaises(StrictDataclassClassValidationError) as exc: + config.save_pretrained(tmp_dir) + # start of the message in `AlbertConfig.validate_architecture` + self.assertTrue("The hidden size " in str(exc.exception)) From 29a4f1ba6d96eb304762cf26c432c6883b56a0d7 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 29 May 2025 18:10:27 +0000 Subject: [PATCH 13/18] type hints --- src/transformers/validators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/validators.py b/src/transformers/validators.py index a6cf132577a3..1d63b896344a 100644 --- a/src/transformers/validators.py +++ b/src/transformers/validators.py @@ -17,7 +17,7 @@ describe the constraints of your dataclass fields, for the best user experience (e.g. better error messages). """ -from typing import Callable, Optional +from typing import Callable, Optional, Union from huggingface_hub.dataclasses import as_validated_field @@ -34,8 +34,8 @@ def interval( - min: Optional[int | float] = None, - max: Optional[int | float] = None, + min: Optional[Union[int, float]] = None, + max: Optional[Union[int, float]] = None, exclude_min: bool = False, exclude_max: bool = False, ) -> Callable: From a3824141717e749161739a56e9e9fffb93bf5364 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 29 May 2025 18:12:32 +0000 Subject: [PATCH 14/18] type hints --- src/transformers/validators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/validators.py b/src/transformers/validators.py index 1d63b896344a..3373619bc358 100644 --- a/src/transformers/validators.py +++ b/src/transformers/validators.py @@ -72,7 +72,7 @@ def interval( max = max or float("inf") @as_validated_field - def _inner(value: int | float): + def _inner(value: Union[int, float]): min_valid = min <= value if not exclude_min else min < value max_valid = value <= max if not exclude_max else value < max if not (min_valid and max_valid): From 48e9c7133a842ec75ee03be7dc5756bea2f2d06e Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 30 May 2025 15:48:23 +0000 Subject: [PATCH 15/18] maybe like this? --- src/transformers/configuration_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 03f55f2a99e6..b8f9061d91de 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -208,7 +208,13 @@ def __getattribute__(self, key): key = super().__getattribute__("attribute_map")[key] return super().__getattribute__(key) + def __post_init__(self): + self._set_defaults() + def __init__(self, **kwargs): + self._set_defaults(**kwargs) + + def _set_defaults(self, **kwargs): # Attributes with defaults self.return_dict = kwargs.pop("return_dict", True) self.output_hidden_states = kwargs.pop("output_hidden_states", False) From e3c46a7ea7b7d931f59194ee5590bb663d764810 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 30 May 2025 16:03:12 +0000 Subject: [PATCH 16/18] special handling for tokens --- src/transformers/configuration_utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index b8f9061d91de..69465a54e8e6 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -215,6 +215,12 @@ def __init__(self, **kwargs): self._set_defaults(**kwargs) def _set_defaults(self, **kwargs): + def _get_default_if_unset(attribute_name, default_value): + set_attribute = getattr(self, attribute_name, None) + if set_attribute is not None: + return set_attribute + return kwargs.pop(attribute_name, default_value) + # Attributes with defaults self.return_dict = kwargs.pop("return_dict", True) self.output_hidden_states = kwargs.pop("output_hidden_states", False) @@ -273,12 +279,11 @@ def _set_defaults(self, **kwargs): # Tokenizer arguments TODO: eventually tokenizer and models should share the same config self.tokenizer_class = kwargs.pop("tokenizer_class", None) self.prefix = kwargs.pop("prefix", None) - self.bos_token_id = kwargs.pop("bos_token_id", None) - self.pad_token_id = kwargs.pop("pad_token_id", None) - self.eos_token_id = kwargs.pop("eos_token_id", None) - self.sep_token_id = kwargs.pop("sep_token_id", None) - - self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) + self.bos_token_id = _get_default_if_unset("bos_token_id", None) + self.pad_token_id = _get_default_if_unset("pad_token_id", None) + self.eos_token_id = _get_default_if_unset("eos_token_id", None) + self.sep_token_id = _get_default_if_unset("sep_token_id", None) + self.decoder_start_token_id = _get_default_if_unset("decoder_start_token_id", None) # task specific arguments self.task_specific_params = kwargs.pop("task_specific_params", None) From 0d942f9cea039e71fb1e23d75cd5816444c25a54 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 30 May 2025 16:45:33 +0000 Subject: [PATCH 17/18] hidden_size value in tests --- tests/models/albert/test_modeling_albert.py | 2 +- tests/test_configuration_common.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/models/albert/test_modeling_albert.py b/tests/models/albert/test_modeling_albert.py index 349a809b07c8..4dddf0ec5b5a 100644 --- a/tests/models/albert/test_modeling_albert.py +++ b/tests/models/albert/test_modeling_albert.py @@ -275,7 +275,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def setUp(self): self.model_tester = AlbertModelTester(self) - self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37) + self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=64) def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 4d4ce3a3f165..7cd081048f7b 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -53,11 +53,12 @@ def create_and_test_config_common_properties(self): self.parent.assertTrue(hasattr(config, prop), msg=f"`{prop}` does not exist") # Test that config has the common properties as setter - for idx, name in enumerate(common_properties): + dummy_value = 64 + for name in common_properties: try: - setattr(config, name, idx) + setattr(config, name, dummy_value) self.parent.assertEqual( - getattr(config, name), idx, msg=f"`{name} value {idx} expected, but was {getattr(config, name)}" + getattr(config, name), dummy_value, msg=f"`{name} value {dummy_value} expected, but was {getattr(config, name)}" ) except NotImplementedError: # Some models might not be able to implement setters for common_properties @@ -65,11 +66,11 @@ def create_and_test_config_common_properties(self): pass # Test if config class can be called with Config(prop_name=..) - for idx, name in enumerate(common_properties): + for name in common_properties: try: - config = self.config_class(**{name: idx}) + config = self.config_class(**{name: dummy_value}) self.parent.assertEqual( - getattr(config, name), idx, msg=f"`{name} value {idx} expected, but was {getattr(config, name)}" + getattr(config, name), dummy_value, msg=f"`{name} value {dummy_value} expected, but was {getattr(config, name)}" ) except NotImplementedError: # Some models might not be able to implement setters for common_properties From 49d8fd26124e4bd883860ae591d841c9d763db9e Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 30 May 2025 17:20:20 +0000 Subject: [PATCH 18/18] attn_implementation getter/setter --- src/transformers/configuration_utils.py | 8 ++++++++ tests/test_configuration_common.py | 8 ++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 69465a54e8e6..4a8b72ee44d4 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -393,6 +393,14 @@ def _attn_implementation(self): def _attn_implementation(self, value): self._attn_implementation_internal = value + @property + def attn_implementation(self): + return self._attn_implementation + + @attn_implementation.setter + def attn_implementation(self, value): + self._attn_implementation = value + def validate_token_ids(self): """Part of `@strict`-powered validation. Validates the contents of the special tokens.""" text_config = self.get_text_config() diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 7cd081048f7b..fcf559d8ae43 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -58,7 +58,9 @@ def create_and_test_config_common_properties(self): try: setattr(config, name, dummy_value) self.parent.assertEqual( - getattr(config, name), dummy_value, msg=f"`{name} value {dummy_value} expected, but was {getattr(config, name)}" + getattr(config, name), + dummy_value, + msg=f"`{name} value {dummy_value} expected, but was {getattr(config, name)}", ) except NotImplementedError: # Some models might not be able to implement setters for common_properties @@ -70,7 +72,9 @@ def create_and_test_config_common_properties(self): try: config = self.config_class(**{name: dummy_value}) self.parent.assertEqual( - getattr(config, name), dummy_value, msg=f"`{name} value {dummy_value} expected, but was {getattr(config, name)}" + getattr(config, name), + dummy_value, + msg=f"`{name} value {dummy_value} expected, but was {getattr(config, name)}", ) except NotImplementedError: # Some models might not be able to implement setters for common_properties