-
Notifications
You must be signed in to change notification settings - Fork 32.7k
[Validation] First implementation of @strict from huggingface_hub
#36534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4b3a954
62b3e12
654845e
6318d84
7cc2a9a
0c7c4c5
4b14696
e73505e
9016a80
e1c6e89
bf3944b
47478c8
29a4f1b
a382414
48e9c71
e3c46a7
0d942f9
49d8fd2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -2025,6 +2025,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): | |||||||
| "`PretrainedConfig`. To create a model from a pretrained model use " | ||||||||
| f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" | ||||||||
| ) | ||||||||
| # class-level validation of config (as opposed to the attribute-level validation provided by `@strict`) | ||||||||
| if hasattr(config, "validate"): | ||||||||
| config.validate() | ||||||||
|
|
||||||||
|
Comment on lines
+2028
to
+2031
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
not needed anymore
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's possible to load a config, modify it to an invalid state, and then use that config to instantiate a model. As such, the model should ensure the config is valid before committing resources. |
||||||||
| if not getattr(config, "_attn_implementation_autoset", False): | ||||||||
| # config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests | ||||||||
| dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype() | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,12 +16,18 @@ | |
| """ALBERT model configuration""" | ||
|
|
||
| from collections import OrderedDict | ||
| from typing import Mapping | ||
| from dataclasses import dataclass | ||
| from typing import Literal, Mapping, Optional, Union | ||
|
|
||
| from huggingface_hub.dataclasses import strict | ||
|
|
||
| from ...configuration_utils import PretrainedConfig | ||
| from ...onnx import OnnxConfig | ||
| from ...validators import activation_fn_key, interval, probability | ||
|
|
||
|
|
||
| @strict(accept_kwargs=True) | ||
| @dataclass | ||
| class AlbertConfig(PretrainedConfig): | ||
| r""" | ||
| This is the configuration class to store the configuration of a [`AlbertModel`] or a [`TFAlbertModel`]. It is used | ||
|
|
@@ -53,9 +59,9 @@ class AlbertConfig(PretrainedConfig): | |
| hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu_new"`): | ||
| The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, | ||
| `"relu"`, `"silu"` and `"gelu_new"` are supported. | ||
| hidden_dropout_prob (`float`, *optional*, defaults to 0): | ||
| hidden_dropout_prob (`float`, *optional*, defaults to 0.0): | ||
| The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | ||
| attention_probs_dropout_prob (`float`, *optional*, defaults to 0): | ||
| attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): | ||
| The dropout ratio for the attention probabilities. | ||
| max_position_embeddings (`int`, *optional*, defaults to 512): | ||
| The maximum sequence length that this model might ever be used with. Typically set this to something large | ||
|
|
@@ -103,51 +109,37 @@ class AlbertConfig(PretrainedConfig): | |
| >>> configuration = model.config | ||
| ```""" | ||
|
|
||
| vocab_size: int = interval(min=1)(default=30000) | ||
| embedding_size: int = interval(min=1)(default=128) | ||
| hidden_size: int = interval(min=1)(default=4096) | ||
| num_hidden_layers: int = interval(min=1)(default=12) | ||
| num_hidden_groups: int = interval(min=1)(default=1) | ||
| num_attention_heads: int = interval(min=0)(default=64) | ||
| intermediate_size: int = interval(min=1)(default=16384) | ||
| inner_group_num: int = interval(min=0)(default=1) | ||
| hidden_act: str = activation_fn_key(default="gelu_new") | ||
|
Comment on lines
+112
to
+120
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BTW I think we should only use default here. let's be less verbose when not needed! Not even sure we need to specify this being an interval! |
||
| hidden_dropout_prob: Union[float, int] = probability(default=0.0) | ||
| attention_probs_dropout_prob: Union[float, int] = probability(default=0.0) | ||
| max_position_embeddings: int = interval(min=0)(default=512) | ||
| type_vocab_size: int = interval(min=1)(default=2) | ||
| initializer_range: float = interval(min=0.0)(default=0.02) | ||
| layer_norm_eps: float = interval(min=0.0)(default=1e-12) | ||
| classifier_dropout_prob: float = probability(default=0.1) | ||
| position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute" | ||
| pad_token_id: Optional[int] = 0 | ||
| bos_token_id: Optional[int] = 2 | ||
| eos_token_id: Optional[int] = 3 | ||
|
|
||
| # Not part of __init__ | ||
| model_type = "albert" | ||
|
|
||
| def __init__( | ||
| self, | ||
| vocab_size=30000, | ||
| embedding_size=128, | ||
| hidden_size=4096, | ||
| num_hidden_layers=12, | ||
| num_hidden_groups=1, | ||
| num_attention_heads=64, | ||
| intermediate_size=16384, | ||
| inner_group_num=1, | ||
| hidden_act="gelu_new", | ||
| hidden_dropout_prob=0, | ||
| attention_probs_dropout_prob=0, | ||
| max_position_embeddings=512, | ||
| type_vocab_size=2, | ||
| initializer_range=0.02, | ||
| layer_norm_eps=1e-12, | ||
| classifier_dropout_prob=0.1, | ||
| position_embedding_type="absolute", | ||
| pad_token_id=0, | ||
| bos_token_id=2, | ||
| eos_token_id=3, | ||
| **kwargs, | ||
| ): | ||
| super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) | ||
|
|
||
| self.vocab_size = vocab_size | ||
| self.embedding_size = embedding_size | ||
| self.hidden_size = hidden_size | ||
| self.num_hidden_layers = num_hidden_layers | ||
| self.num_hidden_groups = num_hidden_groups | ||
| self.num_attention_heads = num_attention_heads | ||
| self.inner_group_num = inner_group_num | ||
| self.hidden_act = hidden_act | ||
| self.intermediate_size = intermediate_size | ||
| self.hidden_dropout_prob = hidden_dropout_prob | ||
| self.attention_probs_dropout_prob = attention_probs_dropout_prob | ||
| self.max_position_embeddings = max_position_embeddings | ||
| self.type_vocab_size = type_vocab_size | ||
| self.initializer_range = initializer_range | ||
| self.layer_norm_eps = layer_norm_eps | ||
| self.classifier_dropout_prob = classifier_dropout_prob | ||
| self.position_embedding_type = position_embedding_type | ||
| def validate_architecture(self): | ||
| """Part of `@strict`-powered validation. Validates the architecture of the config.""" | ||
| if self.hidden_size % self.num_attention_heads != 0: | ||
| raise ValueError( | ||
| f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " | ||
| f"heads ({self.num_attention_heads})." | ||
| ) | ||
|
|
||
|
|
||
| # Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| # coding=utf-8 | ||
| # Copyright 2025-present the HuggingFace Inc. team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
| Validators to be used with `huggingface_hub.dataclasses.validated_field`. We recommend using the validator(s) that best | ||
| describe the constraints of your dataclass fields, for the best user experience (e.g. better error messages). | ||
| """ | ||
|
|
||
| from typing import Callable, Optional, Union | ||
|
|
||
| from huggingface_hub.dataclasses import as_validated_field | ||
|
|
||
| from .utils import is_torch_available | ||
|
|
||
|
|
||
| if is_torch_available(): | ||
| from .activations import ACT2FN | ||
| else: | ||
| ACT2FN = {} | ||
|
|
||
|
|
||
| # Numerical validators | ||
|
|
||
|
|
||
def interval(
    min: Optional[Union[int, float]] = None,
    max: Optional[Union[int, float]] = None,
    exclude_min: bool = False,
    exclude_max: bool = False,
) -> Callable:
    """
    Parameterized validator that ensures that `value` is within the defined interval. Optionally, the interval can be
    open on either side. Expected usage: `interval(min=0)(default=8)`

    Args:
        min (`int` or `float`, *optional*):
            Minimum value of the interval.
        max (`int` or `float`, *optional*):
            Maximum value of the interval.
        exclude_min (`bool`, *optional*, defaults to `False`):
            If True, the minimum value is excluded from the interval.
        exclude_max (`bool`, *optional*, defaults to `False`):
            If True, the maximum value is excluded from the interval.

    Returns:
        `Callable`: A validated-field factory that raises `ValueError` when assigned a value outside the interval.
    """
    # Build the error message up front so it reflects the bounds the caller actually specified.
    error_message = "Value must be"
    if min is not None:
        if exclude_min:
            error_message += f" greater than {min}"
        else:
            error_message += f" greater or equal to {min}"
    if min is not None and max is not None:
        error_message += " and"
    if max is not None:
        if exclude_max:
            error_message += f" smaller than {max}"
        else:
            error_message += f" smaller or equal to {max}"
    error_message += ", got {value}."

    # BUGFIX: `min or float("-inf")` would discard a legitimate bound of `0` (or `0.0`) because
    # `0` is falsy, silently accepting out-of-range values (e.g. `interval(min=0)` let -5 through
    # while the error message still advertised the bound). Compare against `None` explicitly.
    min = min if min is not None else float("-inf")
    max = max if max is not None else float("inf")

    @as_validated_field
    def _inner(value: Union[int, float]):
        # Strict (`<`) vs. non-strict (`<=`) comparison on each side, per the exclude_* flags.
        min_valid = min <= value if not exclude_min else min < value
        max_valid = value <= max if not exclude_max else value < max
        if not (min_valid and max_valid):
            raise ValueError(error_message.format(value=value))

    return _inner
|
|
||
|
|
||
@as_validated_field
def probability(value: float):
    """Ensures that `value` is a valid probability number, i.e. [0,1]."""
    # Keep the chained comparison: it also rejects NaN, since `0 <= nan` is False.
    within_unit_interval = 0 <= value <= 1
    if not within_unit_interval:
        raise ValueError(f"Value must be a probability between 0.0 and 1.0, got {value}.")
|
|
||
|
|
||
| # String validators | ||
|
|
||
|
|
||
@as_validated_field
def activation_fn_key(value: str):
    """Ensures that `value` is a string corresponding to an activation function."""
    # TODO (joao): in python 3.11+, we can build a Literal type from the keys of ACT2FN
    if not ACT2FN:  # don't validate if we can't import ACT2FN
        return
    if value not in ACT2FN:
        raise ValueError(
            f"Value must be one of {list(ACT2FN.keys())}, got {value}. "
            "Make sure to use a string that corresponds to an activation function."
        )
|
|
||
|
|
||
| __all__ = ["interval", "probability", "activation_fn_key"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
validates the config at model init time