[Validation] First implementation of @strict from huggingface_hub#36534
gante wants to merge 18 commits into huggingface:main

Conversation
Wauplin
left a comment
Thanks for giving it a try! I've left a few comments with some first thoughts
src/transformers/modeling_utils.py
Outdated
```python
if hasattr(config, "validate"):  # e.g. in @strict_dataclass
    config.validate()
```
Not done yet but yes, could be a solution for "class-wide validation" in addition to "per-attribute validation"
```python
def __post_init__(self):
    self.validate()
```
Typically something we should move to the @strict_dataclass definition.
The validate method itself would have to be implemented by each class, though.
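A minimal sketch of what such a move could look like. The `strict_dataclass` decorator below is hypothetical (not the eventual huggingface_hub implementation): it wraps the dataclass-generated `__init__` so that a class-defined `validate()` runs automatically.

```python
import functools
from dataclasses import dataclass


def strict_dataclass(cls):
    """Hypothetical sketch: wrap a dataclass's __init__ so that a
    class-defined validate() is called automatically after construction."""
    original_init = cls.__init__

    @functools.wraps(original_init)
    def __init__(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        if hasattr(self, "validate"):  # class-wide validation hook
            self.validate()

    cls.__init__ = __init__
    return cls


@strict_dataclass
@dataclass
class DummyConfig:
    vocab_size: int = 16

    def validate(self):
        if self.vocab_size < 1:
            raise ValueError(f"vocab_size must be >= 1, got {self.vocab_size}")
```

With this, `DummyConfig(vocab_size=0)` raises at construction time without each class having to remember to call `validate()` itself.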
src/transformers/validators.py
Outdated
```python
from typing import Callable, Optional


def interval(min: Optional[int | float] = None, max: Optional[int | float] = None) -> Callable:
    """
    Parameterized validator that ensures that `value` is within the defined interval.
    Expected usage: `validated_field(interval(min=0), default=8)`
    """
    error_message = "Value must be"
    if min is not None:
        error_message += f" at least {min}"
    if min is not None and max is not None:
        error_message += " and"
    if max is not None:
        error_message += f" at most {max}"
    error_message += ", got {value}."

    # `is not None` (rather than `or`) so that an explicit bound of 0 is kept
    min = min if min is not None else float("-inf")
    max = max if max is not None else float("inf")

    def _inner(value: int | float):
        if not min <= value <= max:
            raise ValueError(error_message.format(value=value))

    return _inner


def probability(value: float):
    """Ensures that `value` is a valid probability number, i.e. in [0, 1]."""
    if not 0 <= value <= 1:
        raise ValueError(f"Value must be a probability between 0.0 and 1.0, got {value}.")
```
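A quick usage sketch of the `interval` validator above, re-declared in condensed form so the snippet runs standalone:

```python
from typing import Callable, Optional


def interval(min: Optional[float] = None, max: Optional[float] = None) -> Callable:
    # condensed re-declaration of the validator quoted above
    lo = min if min is not None else float("-inf")
    hi = max if max is not None else float("inf")

    def _inner(value):
        if not lo <= value <= hi:
            raise ValueError(f"Value must be in [{lo}, {hi}], got {value}.")

    return _inner


# a parameterized validator is built once, then applied to values
at_least_one = interval(min=1)
at_least_one(30000)  # passes silently

try:
    at_least_one(0)
except ValueError as e:
    print(e)  # Value must be in [1, inf], got 0.
```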
pydantic defines conint and confloat, which do similar things but in a more generic way. I'd be up to reuse the same naming/interface since that's what people are used to.
- https://docs.pydantic.dev/1.10/usage/types/#arguments-to-conint
- https://docs.pydantic.dev/1.10/usage/types/#arguments-to-confloat
For instance, `interval(min=0)` is unclear compared to `conint(gt=0)` or `conint(ge=0)`. A probability would be `confloat(ge=0.0, le=1.0)`, etc.
I'm open to suggestions, but let me share my thought process :D
Given that we can write arbitrary validators, I would like to avoid that particular interface for two reasons:
1. It heavily relies on non-obvious shortened names, which are a source of confusion -- and against our philosophy in transformers. My goal here was to write something a user without knowledge of `strict_dataclass` could read: for instance, `interval` with a `min` argument on an `int` should be immediately obvious, while `conint` with `gt` requires some prior knowledge (con? gt?).
2. Partial redundancy can result in better UX: a probability is the same as a `float` between `0.0` and `1.0`, but we can write a more precise error message for the user. Take the case of `dropout`-related variables -- while such a variable is technically a probability, it is (was?) bad practice to set it to a value larger than 0.5, and we could be interested in throwing a warning in that case.
Good points here! OK to keep more explicit and redundant APIs then. I just want things to be precise when it comes to intervals (does "between 0.0 and 1.0" mean 0.0 and 1.0 included or not?).
@Wauplin notwithstanding the syntax for validation in discussion above, the current commit is a nearly working version. If we pass standard arguments, everything happens as expected 🙌 However, I've hit one limitation, and I have one further request:
```python
config = AlbertConfig(foo="bar")
print(config.foo)  # should print "bar"
```

I could keep the original …
Definitely possible yes!
So in the example above, you would have:

```python
config = AlbertConfig(foo="bar")
print(config.kwargs["foo"])  # should print "bar"?
```

Feature request accepted 🤗
To be fully BC, what I had in mind, to avoid specifying an …
and either (option A)
or (option B)
@Wauplin: Reviving this project. The current issue is that … hub would then hold a technically sound class for everyone to use 🤗
Potentially related: vllm-project/vllm#14764 -> add shape checkers
@gante I reviewed and updated my previous PR. The main change is that instead of the previous …

This is now doable yes!

```python
@strict(accept_kwargs=True)
@dataclass
class ConfigWithKwargs:
    model_type: str
    vocab_size: int = 16


config = ConfigWithKwargs(model_type="bert", vocab_size=30000, extra_field="extra_value")
print(config)  # ConfigWithKwargs(model_type='bert', vocab_size=30000, *extra_field='extra_value')
```

Check out the documentation on https://moon-ci-docs.huggingface.co/docs/huggingface_hub/pr_2895/en/package_reference/dataclasses
```python
def __post_init__(self):
    """Called after `__init__`: validates the instance."""
    self.validate()
```
@Wauplin as you mentioned in a comment in a prior commit: this could be moved to strict (but it's a minor thing, happy to keep in transformers)
let's have it in huggingface_hub directly!
(PR title changed from "@strict_dataclass from huggingface_hub" to "@strict from huggingface_hub")
```python
# class-level validation of config (as opposed to the attribute-level validation provided by `@strict`)
if hasattr(config, "validate"):
    config.validate()
```
validates the config at model init time
ArthurZucker
left a comment
love it!
Would be nice to have:

```python
vocab_size: int = check(int, default=...)
```

something even more minimal
```python
# Token validation
for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]:
    token_id = getattr(self, token_name)
    if token_id is not None and not 0 <= token_id < self.vocab_size:
        raise ValueError(
            f"{token_name} must be in the vocabulary with size {self.vocab_size}, i.e. between 0 and "
            f"{self.vocab_size - 1}, got {token_id}."
        )
```
These can be moved to the general `PretrainedConfig`'s `validate` function for the common args.
Agreed.
And they would have to be lifted to warnings, I'm sure there are config files out there with negative tokens (this was a common trick in the past to e.g. manipulate generation into not stopping on EOS)
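A sketch of what the lifted check could look like. The `check_token_ids` helper is hypothetical; it warns instead of raising so that legacy configs with out-of-range (e.g. negative) token ids keep loading:

```python
import warnings


def check_token_ids(config) -> None:
    """Hypothetical sketch: lift the token-id range check to a warning, since
    configs in the wild use out-of-range ids on purpose (e.g. negative EOS
    ids were a trick to stop generation from ever hitting EOS)."""
    for token_name in ["pad_token_id", "bos_token_id", "eos_token_id"]:
        token_id = getattr(config, token_name, None)
        if token_id is not None and not 0 <= token_id < config.vocab_size:
            warnings.warn(
                f"{token_name}={token_id} is outside the vocabulary range "
                f"[0, {config.vocab_size - 1}]; this may be intentional but is "
                f"worth double-checking."
            )
```

For instance, a config with `eos_token_id=-1` would load with a warning instead of failing outright.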
````python
>>> configuration = model.config
```"""

vocab_size: int = validated_field(interval(min=1), default=30000)
````
Is there a way to just write `validate` or `check` instead of `check_field`?
```python
num_attention_heads: int = validated_field(interval(min=0), default=64)
intermediate_size: int = validated_field(interval(min=1), default=16384)
inner_group_num: int = validated_field(interval(min=0), default=1)
hidden_act: str = validated_field(activation_fn_key, default="gelu_new")
```
if activation_fn_key is a key, I would rather use ActivationFnKey as a class
```diff
- hidden_act: str = validated_field(activation_fn_key, default="gelu_new")
+ hidden_act: str = check_activation_function(default="gelu_new")
```
In Python 3.11, we can define something like `ActivationFnKey = Literal[possible values]`, and then simply have:

```python
hidden_act: ActivationFnKey = "gelu_new"
```

Until then, I think it's best to keep the same pattern as in the other validated fields, i.e. `validated_field(validation_fn, default)`. WDYT?
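As a side note, a `Literal` annotation does not validate anything at runtime on its own (outside whatever type checking `@strict` performs); a small sketch of checking a value against a `Literal`'s allowed values with `typing.get_args`. The `ActivationFnKey` subset shown is illustrative:

```python
from typing import Literal, get_args

# illustrative subset of activation function names, not the full transformers list
ActivationFnKey = Literal["gelu", "gelu_new", "relu", "silu"]


def check_literal(value, literal_type) -> None:
    """Sketch: runtime check that `value` is one of the Literal's allowed values."""
    allowed = get_args(literal_type)
    if value not in allowed:
        raise ValueError(f"Value must be one of {allowed}, got {value!r}.")


check_literal("gelu_new", ActivationFnKey)  # passes silently
```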
```python
embedding_size: int = validated_field(interval(min=1), default=128)
hidden_size: int = validated_field(interval(min=1), default=4096)
num_hidden_layers: int = validated_field(interval(min=1), default=12)
num_hidden_groups: int = validated_field(interval(min=1), default=1)
num_attention_heads: int = validated_field(interval(min=0), default=64)
intermediate_size: int = validated_field(interval(min=1), default=16384)
inner_group_num: int = validated_field(interval(min=0), default=1)
```
IMO the intervals are too much, we only need the default here!
Would you prefer something like `embedding_size: int = validated_field(strictly_non_negative, default=128)`, or simply `embedding_size: int = 128`?

IMO, since most people don't read the config classes, I think validation is a nice sanity check 🤗
@ArthurZucker @gante I've pushed an update that should greatly reduce the verbosity of validated fields. You can use the `@as_validated_field` decorator when defining your validators, which avoids using `validated_field` on each field. Here is an example:

```python
from dataclasses import dataclass
from typing import Optional

from huggingface_hub.dataclasses import as_validated_field, strict, validated_field


# Case 1: decorated validator with no arguments
@as_validated_field
def probability(value: int):
    if not 0 <= value <= 1:
        raise ValueError(f"Value must be in interval [0, 1], got {value}")


# Case 2: decorated validator with arguments
def interval(min: Optional[int] = None, max: Optional[int] = None):
    @as_validated_field
    def _inner(value: int) -> None:
        if min is not None and value < min:
            raise ValueError(f"Value must be greater than {min}, got {value}")
        if max is not None and value > max:
            raise ValueError(f"Value must be less than {max}, got {value}")

    return _inner


# Case 3: multiple validators
def positive(value: int) -> None:
    if not value >= 0:
        raise ValueError(f"Value must be positive, got {value}")


def multiple_of_2(value: int) -> None:
    if value % 2 != 0:
        raise ValueError(f"Value must be a multiple of 2, got {value}")


@strict
@dataclass
class Config:
    # No custom validation (only type checking)
    model_type: str

    # Validator defined using the decorator
    hidden_dropout_prob: float = probability(default=0.0)

    # Validator with args defined using the decorator
    vocab_size: int = interval(min=10)(default=16)

    # Type checking + 2 validators (more verbose but more explicit)
    hidden_size: int = validated_field([positive, multiple_of_2], default=32)
```
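Chaining multiple validators like in case 3 boils down to applying them in order; a plain-Python sketch of that mechanic (the `run_validators` helper is illustrative, not the huggingface_hub implementation):

```python
from typing import Callable, List


def run_validators(validators: List[Callable], value):
    """Sketch: apply each validator in order; any one of them may raise ValueError."""
    for validator in validators:
        validator(value)
    return value


def positive(value: int) -> None:
    if not value >= 0:
        raise ValueError(f"Value must be positive, got {value}")


def multiple_of_2(value: int) -> None:
    if value % 2 != 0:
        raise ValueError(f"Value must be a multiple of 2, got {value}")


hidden_size = run_validators([positive, multiple_of_2], 32)  # passes both, returns 32
```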
With this syntax and correctly defined validators, the `AlbertConfig` would become:

```python
@strict(accept_kwargs=True)
@dataclass
class AlbertConfig(PretrainedConfig):
    vocab_size: int = interval(min=1)(default=30000)
    embedding_size: int = interval(min=1)(default=128)
    hidden_size: int = interval(min=1)(default=4096)
    num_hidden_layers: int = interval(min=1)(default=12)
    num_hidden_groups: int = interval(min=1)(default=1)
    num_attention_heads: int = interval(min=0)(default=64)
    intermediate_size: int = interval(min=1)(default=16384)
    inner_group_num: int = interval(min=0)(default=1)
    hidden_act: str = activation_fn_key(default="gelu_new")
    hidden_dropout_prob: float = probability(default=0.0)
    attention_probs_dropout_prob: float = probability(default=0.0)
    max_position_embeddings: int = interval(min=0)(default=512)
    type_vocab_size: int = interval(min=1)(default=2)
    initializer_range: float = interval(min=0.0)(default=0.02)
    layer_norm_eps: float = interval(min=0.0)(default=1e-12)
    classifier_dropout_prob: float = probability(default=0.1)
    position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute"
    pad_token_id: Optional[int] = token(default=0)
    bos_token_id: Optional[int] = token(default=2)
    eos_token_id: Optional[int] = token(default=3)
```

which can hardly be less verbose IMO
@Wauplin PR updated with the latest syntax 🙌 Moving … @ArthurZucker syntax is much shorter now 🤗
🎉
huggingface/huggingface_hub@27ca2d7 🫶 with docs in https://moon-ci-docs.huggingface.co/docs/huggingface_hub/pr_2895/en/package_reference/dataclasses#class-validators
With class validators, you can now do that:

```python
@strict(accept_kwargs=True)
@dataclass
class AlbertConfig(PretrainedConfig):
    vocab_size: int = interval(min=1)(default=30000)
    embedding_size: int = interval(min=1)(default=128)
    hidden_size: int = interval(min=1)(default=4096)
    num_hidden_layers: int = interval(min=1)(default=12)
    num_hidden_groups: int = interval(min=1)(default=1)
    num_attention_heads: int = interval(min=0)(default=64)
    intermediate_size: int = interval(min=1)(default=16384)
    inner_group_num: int = interval(min=0)(default=1)
    hidden_act: str = activation_fn_key(default="gelu_new")
    hidden_dropout_prob: float = probability(default=0.0)
    attention_probs_dropout_prob: float = probability(default=0.0)
    max_position_embeddings: int = interval(min=0)(default=512)
    type_vocab_size: int = interval(min=1)(default=2)
    initializer_range: float = interval(min=0.0)(default=0.02)
    layer_norm_eps: float = interval(min=0.0)(default=1e-12)
    classifier_dropout_prob: float = probability(default=0.1)
    position_embedding_type: Literal["absolute", "relative_key", "relative_key_query"] = "absolute"
    pad_token_id: Optional[int] = token(default=0)
    bos_token_id: Optional[int] = token(default=2)
    eos_token_id: Optional[int] = token(default=3)

    # Not part of __init__
    model_type = "albert"

    def validate_architecture(self):
        """Validates the architecture of the model."""
        # Check if the number of attention heads is a divisor of the hidden size
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) must be divisible by the number of attention "
                f"heads ({self.num_attention_heads})."
            )
```

=> Documentation: https://moon-ci-docs.huggingface.co/docs/huggingface_hub/pr_2895/en/package_reference/dataclasses#class-validators
```python
# class-level validation of config (as opposed to the attribute-level validation provided by `@strict`)
if hasattr(config, "validate"):
    config.validate()
```

Suggested change:

```diff
- # class-level validation of config (as opposed to the attribute-level validation provided by `@strict`)
- if hasattr(config, "validate"):
-     config.validate()
```
not needed anymore
It's possible to load a config, modify it to an invalid state, and then use that config to instantiate a model :( As such, the model should ensure the config is valid before committing resources.
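A toy sketch of that argument (all names hypothetical): re-validating at model init catches configs that were mutated into an invalid state after construction.

```python
class ToyConfig:
    """Hypothetical config with a class-level validate() hook."""

    def __init__(self, vocab_size=16):
        self.vocab_size = vocab_size

    def validate(self):
        if self.vocab_size < 1:
            raise ValueError(f"vocab_size must be >= 1, got {self.vocab_size}")


class ToyModel:
    """Hypothetical model: validate the config before committing resources
    (weight allocation, etc.), since it may have been mutated after loading."""

    def __init__(self, config):
        if hasattr(config, "validate"):  # class-level validation at init time
            config.validate()
        self.config = config


config = ToyConfig()
config.vocab_size = -1  # mutate into an invalid state after construction
# ToyModel(config) would now raise ValueError instead of building a broken model
```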
Co-authored-by: Lucain <lucainp@gmail.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
In favor of removing `interval` when it's obvious!
```python
vocab_size: int = interval(min=1)(default=30000)
embedding_size: int = interval(min=1)(default=128)
hidden_size: int = interval(min=1)(default=4096)
num_hidden_layers: int = interval(min=1)(default=12)
num_hidden_groups: int = interval(min=1)(default=1)
num_attention_heads: int = interval(min=0)(default=64)
intermediate_size: int = interval(min=1)(default=16384)
inner_group_num: int = interval(min=0)(default=1)
hidden_act: str = activation_fn_key(default="gelu_new")
```
BTW I think we should only use the default here. Let's be less verbose when not needed! Not even sure we need to specify this being an interval!
Note to myself before I go on holidays: now that … Let's ignore the validation logic itself, which is working well, and focus on the pressing issue: handling inheritance of …

Issue: base class inheritance

We are defining a model-level …
This is excellent work, it will be amazing to have 🚀 Super looking forward!
(If someone stumbles upon this in the future: @zucchini-nlp is working on a better …)
What does this PR do?
Testbed for huggingface/huggingface_hub#2895, released recently.
Core ideas:
- validation at `__init__` and assignment time, using strict type checks and custom value validators;
- class-level validation at `__init__` through `dataclass`'s `__post_init__`, or manually through `.validate()` (`.validate()` must be called to validate at a class level). With this, we can check e.g. if a special token is within the vocabulary range;
- extra `validate_xxx` methods: `@strict` adds them to its validation functions.

Minimal test script:
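As a stand-in for the script, here is a self-contained sketch of the three ideas above, using plain dataclass machinery instead of `@strict` (illustrative only, no huggingface_hub dependency):

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    """Illustrative stand-in for a @strict-validated config."""

    vocab_size: int = 30000
    hidden_dropout_prob: float = 0.0
    eos_token_id: int = 3

    def __post_init__(self):
        self.validate()  # class-level validation at __init__ time

    def validate(self):
        # per-attribute style checks
        if self.vocab_size < 1:
            raise ValueError(f"vocab_size must be >= 1, got {self.vocab_size}")
        if not 0.0 <= self.hidden_dropout_prob <= 1.0:
            raise ValueError(f"hidden_dropout_prob must be in [0, 1], got {self.hidden_dropout_prob}")
        # cross-attribute check: special token within the vocabulary range
        if not 0 <= self.eos_token_id < self.vocab_size:
            raise ValueError(f"eos_token_id must be in [0, {self.vocab_size - 1}], got {self.eos_token_id}")


config = ToyConfig()  # passes validation

try:
    ToyConfig(eos_token_id=40000)  # token outside the vocab -> error
except ValueError as e:
    print(e)  # eos_token_id must be in [0, 29999], got 40000
```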