From b863be7f69358c2d6695918b45a4bd3645e5d8c8 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 5 May 2023 22:33:12 +0100 Subject: [PATCH] RootModel (#592) --- pydantic_core/core_schema.py | 4 + src/validators/model.rs | 151 ++++++++++++++-------- tests/benchmarks/test_micro_benchmarks.py | 26 ++++ tests/validators/test_model_init.py | 43 +++++- tests/validators/test_model_root.py | 127 ++++++++++++++++++ 5 files changed, 296 insertions(+), 55 deletions(-) create mode 100644 tests/validators/test_model_root.py diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index 8c3676edf2..008c08be57 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -2838,6 +2838,7 @@ class ModelSchema(TypedDict, total=False): cls: Required[Type[Any]] schema: Required[CoreSchema] custom_init: bool + root_model: bool post_init: str revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never' strict: bool @@ -2854,6 +2855,7 @@ def model_schema( schema: CoreSchema, *, custom_init: bool | None = None, + root_model: bool | None = None, post_init: str | None = None, revalidate_instances: Literal['always', 'never', 'subclass-instances'] | None = None, strict: bool | None = None, @@ -2894,6 +2896,7 @@ class MyModel: cls: The class to use for the model schema: The schema to use for the model custom_init: Whether the model has a custom init method + root_model: Whether the model is a `RootModel` post_init: The call after init to use for the model revalidate_instances: whether instances of models and dataclasses (including subclass instances) should re-validate defaults to config.revalidate_instances, else 'never' @@ -2910,6 +2913,7 @@ class MyModel: cls=cls, schema=schema, custom_init=custom_init, + root_model=root_model, post_init=post_init, revalidate_instances=revalidate_instances, strict=strict, diff --git a/src/validators/model.rs b/src/validators/model.rs index a6cca9a308..54e6583dee 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -14,6 +14,7 @@ use crate::recursion_guard::RecursionGuard; use super::function::convert_err; use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +const ROOT_FIELD: &str = "root"; const DUNDER_DICT: &str = "__dict__"; const DUNDER_FIELDS_SET_KEY: &str = "__pydantic_fields_set__"; const DUNDER_MODEL_EXTRA_KEY: &str = "__pydantic_extra__"; @@ -52,9 +53,10 @@ pub struct ModelValidator { validator: Box, class: Py, post_init: Option>, - name: String, frozen: bool, custom_init: bool, + root_model: bool, + name: String, } impl BuildValidator for ModelValidator { @@ -87,11 +89,12 @@ impl BuildValidator for ModelValidator { post_init: schema .get_as::<&str>(intern!(py, "post_init"))? .map(|s| PyString::intern(py, s).into_py(py)), + frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), + custom_init: schema.get_as(intern!(py, "custom_init"))?.unwrap_or(false), + root_model: schema.get_as(intern!(py, "root_model"))?.unwrap_or(false), // Get the class's `__name__`, not using `class.name()` since it uses `__qualname__` // which is not what we want here name: class.getattr(intern!(py, "__name__"))?.extract()?, - frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), - custom_init: schema.get_as(intern!(py, "custom_init"))?.unwrap_or(false), } .into()) } @@ -125,28 +128,24 @@ impl Validator for ModelValidator { // mask 0 so JSON is input is never true here if input.input_is_instance(class, 0)? { if self.revalidate.should_revalidate(input, class) { - let fields_set = input.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap()?; - - // get dict here so from_attributes logic doesn't apply - let dict = input.input_get_attr(intern!(py, DUNDER_DICT)).unwrap()?; - let model_extra = input.input_get_attr(intern!(py, DUNDER_MODEL_EXTRA_KEY)).unwrap()?; - - let full_model_dict: &PyAny = if model_extra.is_none() { - dict + if self.root_model { + let inner_input: &PyAny = input.input_get_attr(intern!(py, ROOT_FIELD)).unwrap()?; + self.validate_construct(py, inner_input, None, extra, definitions, recursion_guard) } else { - let full_model_dict = dict.downcast::()?.copy()?; - full_model_dict.update(model_extra.downcast()?)?; - full_model_dict - }; - - let output = self - .validator - .validate(py, full_model_dict, extra, definitions, recursion_guard)?; - - let (model_dict, model_extra, _): (&PyAny, &PyAny, &PyAny) = output.extract(py)?; - let instance = self.create_class(model_dict, model_extra, fields_set)?; - - self.call_post_init(py, instance, input, extra) + let fields_set = input.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap()?; + // get dict here so from_attributes logic doesn't apply + let dict = input.input_get_attr(intern!(py, DUNDER_DICT)).unwrap()?; + let model_extra = input.input_get_attr(intern!(py, DUNDER_MODEL_EXTRA_KEY)).unwrap()?; + + let inner_input: &PyAny = if model_extra.is_none() { + dict + } else { + let full_model_dict = dict.downcast::()?.copy()?; + full_model_dict.update(model_extra.downcast()?)?; + full_model_dict + }; + self.validate_construct(py, inner_input, Some(fields_set), extra, definitions, recursion_guard) + } } else { Ok(input.to_object(py)) } @@ -158,22 +157,7 @@ impl Validator for ModelValidator { input, )) } else { - if self.custom_init { - // If we wanted, we could introspect the __init__ signature, and store the - // keyword arguments and types, and create a validator for them. - // Perhaps something similar to `validate_call`? Could probably make - // this work with from_attributes, and would essentially allow you to - // handle init vars by adding them to the __init__ signature. - if let Some(kwargs) = input.as_kwargs(py) { - return Ok(self.class.call(py, (), Some(kwargs))?); - } - } - let output = self - .validator - .validate(py, input, extra, definitions, recursion_guard)?; - let (model_dict, model_extra, fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?; - let instance = self.create_class(model_dict, model_extra, fields_set)?; - self.call_post_init(py, instance, input, extra) + self.validate_construct(py, input, None, extra, definitions, recursion_guard) } } @@ -189,9 +173,29 @@ impl Validator for ModelValidator { ) -> ValResult<'data, PyObject> { if self.frozen { return Err(ValError::new(ErrorType::FrozenInstance, field_value)); + } else if self.root_model { + return if field_name != ROOT_FIELD { + Err(ValError::new_with_loc( + ErrorType::NoSuchAttribute { + attribute: field_name.to_string(), + }, + field_value, + field_name.to_string(), + )) + } else { + let field_extra = Extra { + field_name: Some(field_name), + ..*extra + }; + let output = self + .validator + .validate(py, field_value, &field_extra, definitions, recursion_guard)?; + + force_setattr(py, model, intern!(py, ROOT_FIELD), output)?; + Ok(model.into_py(py)) + }; } - let dict_py_str = intern!(py, DUNDER_DICT); - let dict: &PyDict = model.getattr(dict_py_str)?.downcast()?; + let dict: &PyDict = model.getattr(intern!(py, DUNDER_DICT))?.downcast()?; let new_dict = dict.copy()?; new_dict.set_item(field_name, field_value)?; @@ -216,7 +220,7 @@ impl Validator for ModelValidator { } let output = output.to_object(py); - force_setattr(py, model, dict_py_str, output)?; + force_setattr(py, model, intern!(py, DUNDER_DICT), output)?; Ok(model.into_py(py)) } @@ -262,11 +266,61 @@ impl ModelValidator { let output = self .validator .validate(py, input, &new_extra, definitions, recursion_guard)?; - let (model_dict, model_extra, fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?; - set_model_attrs(self_instance, model_dict, model_extra, fields_set)?; + + if self.root_model { + force_setattr(py, self_instance, intern!(py, ROOT_FIELD), output.as_ref(py))?; + } else { + let (model_dict, model_extra, fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?; + set_model_attrs(self_instance, model_dict, model_extra, fields_set)?; + } self.call_post_init(py, self_instance.into_py(py), input, extra) } + fn validate_construct<'s, 'data>( + &'s self, + py: Python<'data>, + input: &'data impl Input<'data>, + existing_fields_set: Option<&'data PyAny>, + extra: &Extra, + definitions: &'data Definitions, + recursion_guard: &'s mut RecursionGuard, + ) -> ValResult<'data, PyObject> { + if self.custom_init { + // If we wanted, we could introspect the __init__ signature, and store the + // keyword arguments and types, and create a validator for them. + // Perhaps something similar to `validate_call`? Could probably make + // this work with from_attributes, and would essentially allow you to + // handle init vars by adding them to the __init__ signature. + if let Some(kwargs) = input.as_kwargs(py) { + return Ok(self.class.call(py, (), Some(kwargs))?); + } + } + + let output = if self.root_model { + let field_extra = Extra { + field_name: Some(ROOT_FIELD), + ..*extra + }; + self.validator + .validate(py, input, &field_extra, definitions, recursion_guard)? + } else { + self.validator + .validate(py, input, extra, definitions, recursion_guard)? + }; + + let instance = create_class(self.class.as_ref(py))?; + let instance_ref = instance.as_ref(py); + + if self.root_model { + force_setattr(py, instance_ref, intern!(py, ROOT_FIELD), output)?; + } else { + let (model_dict, model_extra, val_fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?; + let fields_set = existing_fields_set.unwrap_or(val_fields_set); + set_model_attrs(instance_ref, model_dict, model_extra, fields_set)?; + } + self.call_post_init(py, instance, input, extra) + } + fn call_post_init<'s, 'data>( &'s self, py: Python<'data>, @@ -281,13 +335,6 @@ impl ModelValidator { } Ok(instance) } - - fn create_class(&self, model_dict: &PyAny, model_extra: &PyAny, fields_set: &PyAny) -> PyResult { - let py = model_dict.py(); - let instance = create_class(self.class.as_ref(py))?; - set_model_attrs(instance.as_ref(py), model_dict, model_extra, fields_set)?; - Ok(instance) - } } /// based on the following but with the second argument of new_func set to an empty tuple as required diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index 655f21e3e6..a2bf8b0f43 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -1313,3 +1313,29 @@ def test_validate_literal( assert res == expected_val_res benchmark(validator.validate_json, input_json) + + +@pytest.mark.benchmark(group='root_model') +def test_core_root_model(benchmark): + class MyModel: + __slots__ = 'root' + root: List[int] + + v = SchemaValidator( + core_schema.model_schema(MyModel, core_schema.list_schema(core_schema.int_schema()), root_model=True) + ) + assert v.validate_python([1, 2, '3']).root == [1, 2, 3] + input_data = list(range(100)) + benchmark(v.validate_python, input_data) + + +@skip_pydantic +@pytest.mark.benchmark(group='root_model') +def test_v1_root_model(benchmark): + class MyModel(BaseModel): + __root__: List[int] + + assert MyModel.parse_obj([1, 2, '3']).__root__ == [1, 2, 3] + input_data = list(range(100)) + + benchmark(MyModel.parse_obj, input_data) diff --git a/tests/validators/test_model_init.py b/tests/validators/test_model_init.py index b059c381fd..380c81bb23 100644 --- a/tests/validators/test_model_init.py +++ b/tests/validators/test_model_init.py @@ -33,9 +33,9 @@ def test_model_init(): m2 = MyModel() ans = v.validate_python({'field_a': 'test', 'field_b': 12}, self_instance=m2) assert ans == m2 - assert m.field_a == 'test' - assert m.field_b == 12 - assert m.__pydantic_fields_set__ == {'field_a', 'field_b'} + assert ans.field_a == 'test' + assert ans.field_b == 12 + assert ans.__pydantic_fields_set__ == {'field_a', 'field_b'} def test_model_init_nested(): @@ -381,3 +381,40 @@ def __init__(self, **data): ('inner', {'a': 1, 'b': 3}, {'b', 'z'}, {'z': 1}), ('outer', {'a': 2, 'b': IsInstance(ModelInner)}, {'c', 'a', 'b'}, {'c': 1}), ] + + +def test_model_custom_init_revalidate(): + calls = [] + + class Model: + __slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__' + + def __init__(self, **kwargs): + calls.append(repr(kwargs)) + self.__dict__.update(kwargs) + self.__pydantic_fields_set__ = {'custom'} + self.__pydantic_extra__ = None + + v = SchemaValidator( + core_schema.model_schema( + Model, + core_schema.model_fields_schema({'a': core_schema.model_field(core_schema.int_schema())}), + custom_init=True, + config=dict(revalidate_instances='always'), + ) + ) + + m = v.validate_python({'a': '1'}) + assert isinstance(m, Model) + assert m.a == '1' + assert m.__pydantic_fields_set__ == {'custom'} + assert calls == ["{'a': '1'}"] + m.x = 4 + + m2 = v.validate_python(m) + assert m2 is not m + assert isinstance(m2, Model) + assert m2.a == '1' + assert m2.__dict__ == {'a': '1', 'x': 4} + assert m2.__pydantic_fields_set__ == {'custom'} + assert calls == ["{'a': '1'}", "{'a': '1', 'x': 4}"] diff --git a/tests/validators/test_model_root.py b/tests/validators/test_model_root.py new file mode 100644 index 0000000000..6f21f6bb03 --- /dev/null +++ b/tests/validators/test_model_root.py @@ -0,0 +1,127 @@ +from typing import List + +import pytest + +from pydantic_core import SchemaValidator, ValidationError, core_schema + + +def test_model_root(): + class RootModel: + __slots__ = 'root' + root: List[int] + + v = SchemaValidator( + core_schema.model_schema(RootModel, core_schema.list_schema(core_schema.int_schema()), root_model=True) + ) + assert repr(v).startswith('SchemaValidator(title="RootModel", validator=Model(\n') + + m = v.validate_python([1, 2, '3']) + assert isinstance(m, RootModel) + assert m.root == [1, 2, 3] + assert not hasattr(m, '__dict__') + + m = v.validate_json('[1, 2, "3"]') + assert isinstance(m, RootModel) + assert m.root == [1, 2, 3] + + with pytest.raises(ValidationError) as exc_info: + v.validate_python('wrong') + + # insert_assert(exc_info.value.errors(include_url=False)) + assert exc_info.value.errors(include_url=False) == [ + {'type': 'list_type', 'loc': (), 'msg': 'Input should be a valid list', 'input': 'wrong'} + ] + + +def test_revalidate(): + class RootModel: + __slots__ = 'root' + root: List[int] + + v = SchemaValidator( + core_schema.model_schema( + RootModel, core_schema.list_schema(core_schema.int_schema()), root_model=True, revalidate_instances='always' + ) + ) + m = v.validate_python([1, '2']) + assert isinstance(m, RootModel) + assert m.root == [1, 2] + + m2 = v.validate_python(m) + assert m2 is not m + assert isinstance(m2, RootModel) + assert m2.root == [1, 2] + + +def test_init(): + class RootModel: + __slots__ = 'root' + root: str + + v = SchemaValidator( + core_schema.model_schema(RootModel, core_schema.str_schema(), root_model=True, revalidate_instances='always') + ) + + m = RootModel() + ans = v.validate_python('foobar', self_instance=m) + assert ans is m + assert ans.root == 'foobar' + + +def test_assignment(): + class RootModel: + __slots__ = 'root' + root: str + + v = SchemaValidator(core_schema.model_schema(RootModel, core_schema.str_schema(), root_model=True)) + + m = v.validate_python('foobar') + assert m.root == 'foobar' + + m2 = v.validate_assignment(m, 'root', 'baz') + assert m2 is m + assert m.root == 'baz' + + with pytest.raises(ValidationError) as exc_info: + v.validate_assignment(m, 'different', 'baz') + + # insert_assert(exc_info.value.errors(include_url=False)) + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'no_such_attribute', + 'loc': ('different',), + 'msg': "Object has no attribute 'different'", + 'input': 'baz', + 'ctx': {'attribute': 'different'}, + } + ] + + +def test_field_function(): + call_infos = [] + + class RootModel: + __slots__ = 'root' + root: str + + def f(input_value: str, info): + call_infos.append(repr(info)) + return input_value + ' validated' + + v = SchemaValidator( + core_schema.model_schema( + RootModel, core_schema.field_after_validator_function(f, core_schema.str_schema()), root_model=True + ) + ) + m = v.validate_python('foobar', context='call 1') + assert isinstance(m, RootModel) + assert m.root == 'foobar validated' + assert call_infos == ["ValidationInfo(config=None, context='call 1', field_name='root')"] + + m2 = v.validate_assignment(m, 'root', 'baz', context='assignment call') + assert m2 is m + assert m.root == 'baz validated' + assert call_infos == [ + "ValidationInfo(config=None, context='call 1', field_name='root')", + "ValidationInfo(config=None, context='assignment call', field_name='root')", + ]