Skip to content

Commit

Permalink
RootModel (pydantic#592)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored May 5, 2023
1 parent 493643e commit b863be7
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 55 deletions.
4 changes: 4 additions & 0 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2838,6 +2838,7 @@ class ModelSchema(TypedDict, total=False):
cls: Required[Type[Any]]
schema: Required[CoreSchema]
custom_init: bool
root_model: bool
post_init: str
revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never'
strict: bool
Expand All @@ -2854,6 +2855,7 @@ def model_schema(
schema: CoreSchema,
*,
custom_init: bool | None = None,
root_model: bool | None = None,
post_init: str | None = None,
revalidate_instances: Literal['always', 'never', 'subclass-instances'] | None = None,
strict: bool | None = None,
Expand Down Expand Up @@ -2894,6 +2896,7 @@ class MyModel:
cls: The class to use for the model
schema: The schema to use for the model
custom_init: Whether the model has a custom init method
root_model: Whether the model is a `RootModel`
post_init: The call after init to use for the model
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
should be re-validated; defaults to config.revalidate_instances, else 'never'
Expand All @@ -2910,6 +2913,7 @@ class MyModel:
cls=cls,
schema=schema,
custom_init=custom_init,
root_model=root_model,
post_init=post_init,
revalidate_instances=revalidate_instances,
strict=strict,
Expand Down
151 changes: 99 additions & 52 deletions src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::recursion_guard::RecursionGuard;
use super::function::convert_err;
use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};

/// Attribute name under which a root model's validated value is stored on the instance.
const ROOT_FIELD: &str = "root";
/// Attribute name of the instance dict that holds a (non-root) model's field values.
const DUNDER_DICT: &str = "__dict__";
/// Attribute name read from an existing instance to obtain its fields-set when revalidating.
const DUNDER_FIELDS_SET_KEY: &str = "__pydantic_fields_set__";
/// Attribute name read from an existing instance to obtain its extra values when revalidating.
const DUNDER_MODEL_EXTRA_KEY: &str = "__pydantic_extra__";
Expand Down Expand Up @@ -52,9 +53,10 @@ pub struct ModelValidator {
validator: Box<CombinedValidator>,
class: Py<PyType>,
post_init: Option<Py<PyString>>,
name: String,
frozen: bool,
custom_init: bool,
root_model: bool,
name: String,
}

impl BuildValidator for ModelValidator {
Expand Down Expand Up @@ -87,11 +89,12 @@ impl BuildValidator for ModelValidator {
post_init: schema
.get_as::<&str>(intern!(py, "post_init"))?
.map(|s| PyString::intern(py, s).into_py(py)),
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
custom_init: schema.get_as(intern!(py, "custom_init"))?.unwrap_or(false),
root_model: schema.get_as(intern!(py, "root_model"))?.unwrap_or(false),
// Get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
// which is not what we want here
name: class.getattr(intern!(py, "__name__"))?.extract()?,
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
custom_init: schema.get_as(intern!(py, "custom_init"))?.unwrap_or(false),
}
.into())
}
Expand Down Expand Up @@ -125,28 +128,24 @@ impl Validator for ModelValidator {
// mask 0 so this is never true for JSON input
if input.input_is_instance(class, 0)? {
if self.revalidate.should_revalidate(input, class) {
let fields_set = input.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap()?;

// get dict here so from_attributes logic doesn't apply
let dict = input.input_get_attr(intern!(py, DUNDER_DICT)).unwrap()?;
let model_extra = input.input_get_attr(intern!(py, DUNDER_MODEL_EXTRA_KEY)).unwrap()?;

let full_model_dict: &PyAny = if model_extra.is_none() {
dict
if self.root_model {
let inner_input: &PyAny = input.input_get_attr(intern!(py, ROOT_FIELD)).unwrap()?;
self.validate_construct(py, inner_input, None, extra, definitions, recursion_guard)
} else {
let full_model_dict = dict.downcast::<PyDict>()?.copy()?;
full_model_dict.update(model_extra.downcast()?)?;
full_model_dict
};

let output = self
.validator
.validate(py, full_model_dict, extra, definitions, recursion_guard)?;

let (model_dict, model_extra, _): (&PyAny, &PyAny, &PyAny) = output.extract(py)?;
let instance = self.create_class(model_dict, model_extra, fields_set)?;

self.call_post_init(py, instance, input, extra)
let fields_set = input.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap()?;
// get dict here so from_attributes logic doesn't apply
let dict = input.input_get_attr(intern!(py, DUNDER_DICT)).unwrap()?;
let model_extra = input.input_get_attr(intern!(py, DUNDER_MODEL_EXTRA_KEY)).unwrap()?;

let inner_input: &PyAny = if model_extra.is_none() {
dict
} else {
let full_model_dict = dict.downcast::<PyDict>()?.copy()?;
full_model_dict.update(model_extra.downcast()?)?;
full_model_dict
};
self.validate_construct(py, inner_input, Some(fields_set), extra, definitions, recursion_guard)
}
} else {
Ok(input.to_object(py))
}
Expand All @@ -158,22 +157,7 @@ impl Validator for ModelValidator {
input,
))
} else {
if self.custom_init {
// If we wanted, we could introspect the __init__ signature, and store the
// keyword arguments and types, and create a validator for them.
// Perhaps something similar to `validate_call`? Could probably make
// this work with from_attributes, and would essentially allow you to
// handle init vars by adding them to the __init__ signature.
if let Some(kwargs) = input.as_kwargs(py) {
return Ok(self.class.call(py, (), Some(kwargs))?);
}
}
let output = self
.validator
.validate(py, input, extra, definitions, recursion_guard)?;
let (model_dict, model_extra, fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?;
let instance = self.create_class(model_dict, model_extra, fields_set)?;
self.call_post_init(py, instance, input, extra)
self.validate_construct(py, input, None, extra, definitions, recursion_guard)
}
}

Expand All @@ -189,9 +173,29 @@ impl Validator for ModelValidator {
) -> ValResult<'data, PyObject> {
if self.frozen {
return Err(ValError::new(ErrorType::FrozenInstance, field_value));
} else if self.root_model {
return if field_name != ROOT_FIELD {
Err(ValError::new_with_loc(
ErrorType::NoSuchAttribute {
attribute: field_name.to_string(),
},
field_value,
field_name.to_string(),
))
} else {
let field_extra = Extra {
field_name: Some(field_name),
..*extra
};
let output = self
.validator
.validate(py, field_value, &field_extra, definitions, recursion_guard)?;

force_setattr(py, model, intern!(py, ROOT_FIELD), output)?;
Ok(model.into_py(py))
};
}
let dict_py_str = intern!(py, DUNDER_DICT);
let dict: &PyDict = model.getattr(dict_py_str)?.downcast()?;
let dict: &PyDict = model.getattr(intern!(py, DUNDER_DICT))?.downcast()?;

let new_dict = dict.copy()?;
new_dict.set_item(field_name, field_value)?;
Expand All @@ -216,7 +220,7 @@ impl Validator for ModelValidator {
}
let output = output.to_object(py);

force_setattr(py, model, dict_py_str, output)?;
force_setattr(py, model, intern!(py, DUNDER_DICT), output)?;
Ok(model.into_py(py))
}

Expand Down Expand Up @@ -262,11 +266,61 @@ impl ModelValidator {
let output = self
.validator
.validate(py, input, &new_extra, definitions, recursion_guard)?;
let (model_dict, model_extra, fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?;
set_model_attrs(self_instance, model_dict, model_extra, fields_set)?;

if self.root_model {
force_setattr(py, self_instance, intern!(py, ROOT_FIELD), output.as_ref(py))?;
} else {
let (model_dict, model_extra, fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?;
set_model_attrs(self_instance, model_dict, model_extra, fields_set)?;
}
self.call_post_init(py, self_instance.into_py(py), input, extra)
}

/// Validate `input` with the inner validator and build a new instance of `self.class` from the result.
///
/// * `input` - the value handed to the inner validator (or to the class itself when `custom_init` applies).
/// * `existing_fields_set` - when `Some`, used as the fields-set of the new instance instead of the
///   fields-set produced by validation; ignored for root models and on the `custom_init` shortcut.
/// * `extra`, `definitions`, `recursion_guard` - forwarded to the inner validator.
///
/// Returns the constructed instance after `call_post_init` has run, or — on the `custom_init`
/// shortcut — whatever calling the class returned.
///
/// # Errors
///
/// Propagates any error from the class call, the inner validator, instance creation,
/// attribute assignment, tuple extraction, or `call_post_init`.
fn validate_construct<'s, 'data>(
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
existing_fields_set: Option<&'data PyAny>,
extra: &Extra,
definitions: &'data Definitions<CombinedValidator>,
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
if self.custom_init {
// If we wanted, we could introspect the __init__ signature, and store the
// keyword arguments and types, and create a validator for them.
// Perhaps something similar to `validate_call`? Could probably make
// this work with from_attributes, and would essentially allow you to
// handle init vars by adding them to the __init__ signature.
// When the input can be presented as kwargs, the class is called with them and its
// result is returned as-is: the inner validator and `call_post_init` are not run here.
if let Some(kwargs) = input.as_kwargs(py) {
return Ok(self.class.call(py, (), Some(kwargs))?);
}
}

// Root models validate the input with `field_name` set to ROOT_FIELD; every other
// `Extra` field is copied from the caller's `extra`.
let output = if self.root_model {
let field_extra = Extra {
field_name: Some(ROOT_FIELD),
..*extra
};
self.validator
.validate(py, input, &field_extra, definitions, recursion_guard)?
} else {
self.validator
.validate(py, input, extra, definitions, recursion_guard)?
};

// The instance is only created once validation has succeeded.
// NOTE(review): presumably the free `create_class` helper builds the instance without
// calling `__init__` — confirm against its definition.
let instance = create_class(self.class.as_ref(py))?;
let instance_ref = instance.as_ref(py);

if self.root_model {
// The validated value itself becomes the ROOT_FIELD attribute.
force_setattr(py, instance_ref, intern!(py, ROOT_FIELD), output)?;
} else {
// The inner validator's output is unpacked as (model dict, model extra, fields set).
let (model_dict, model_extra, val_fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?;
// A fields-set supplied by the caller wins over the one computed by validation.
let fields_set = existing_fields_set.unwrap_or(val_fields_set);
set_model_attrs(instance_ref, model_dict, model_extra, fields_set)?;
}
// `call_post_init` receives the caller's `extra`, not the root-specific `field_extra`.
self.call_post_init(py, instance, input, extra)
}

fn call_post_init<'s, 'data>(
&'s self,
py: Python<'data>,
Expand All @@ -281,13 +335,6 @@ impl ModelValidator {
}
Ok(instance)
}

/// Create a new instance of `self.class` and populate it from validated model data.
///
/// The GIL token is taken from `model_dict`; the instance comes from the free
/// `create_class` helper and is filled in by `set_model_attrs` before being returned.
///
/// # Errors
///
/// Propagates any error raised while creating the instance or setting its attributes.
fn create_class(&self, model_dict: &PyAny, model_extra: &PyAny, fields_set: &PyAny) -> PyResult<PyObject> {
    let gil = model_dict.py();
    let new_instance = create_class(self.class.as_ref(gil))?;
    let instance_ref = new_instance.as_ref(gil);
    set_model_attrs(instance_ref, model_dict, model_extra, fields_set)?;
    Ok(new_instance)
}
}

/// based on the following but with the second argument of new_func set to an empty tuple as required
Expand Down
26 changes: 26 additions & 0 deletions tests/benchmarks/test_micro_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,3 +1313,29 @@ def test_validate_literal(
assert res == expected_val_res

benchmark(validator.validate_json, input_json)


@pytest.mark.benchmark(group='root_model')
def test_core_root_model(benchmark):
    """Benchmark validating a list of ints through a model schema built with `root_model=True`."""

    class MyModel:
        __slots__ = 'root'
        root: List[int]

    items_schema = core_schema.list_schema(core_schema.int_schema())
    validator = SchemaValidator(core_schema.model_schema(MyModel, items_schema, root_model=True))

    # Sanity check before timing: the validated list ends up on the `root` attribute.
    validated = validator.validate_python([1, 2, '3'])
    assert validated.root == [1, 2, 3]

    payload = list(range(100))
    benchmark(validator.validate_python, payload)


@skip_pydantic
@pytest.mark.benchmark(group='root_model')
def test_v1_root_model(benchmark):
    """Benchmark `parse_obj` on a `BaseModel` subclass declaring a `__root__` list of ints."""

    class MyModel(BaseModel):
        __root__: List[int]

    # Sanity check before timing: the parsed list is exposed as `__root__`.
    parsed = MyModel.parse_obj([1, 2, '3'])
    assert parsed.__root__ == [1, 2, 3]

    payload = list(range(100))
    benchmark(MyModel.parse_obj, payload)
43 changes: 40 additions & 3 deletions tests/validators/test_model_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def test_model_init():
m2 = MyModel()
ans = v.validate_python({'field_a': 'test', 'field_b': 12}, self_instance=m2)
assert ans == m2
assert m.field_a == 'test'
assert m.field_b == 12
assert m.__pydantic_fields_set__ == {'field_a', 'field_b'}
assert ans.field_a == 'test'
assert ans.field_b == 12
assert ans.__pydantic_fields_set__ == {'field_a', 'field_b'}


def test_model_init_nested():
Expand Down Expand Up @@ -381,3 +381,40 @@ def __init__(self, **data):
('inner', {'a': 1, 'b': 3}, {'b', 'z'}, {'z': 1}),
('outer', {'a': 2, 'b': IsInstance(ModelInner)}, {'c', 'a', 'b'}, {'c': 1}),
]


def test_model_custom_init_revalidate():
    """With `custom_init=True` and `revalidate_instances='always'`, both the first validation
    and the revalidation of an existing instance go through the model's own `__init__`."""
    init_calls = []

    class Model:
        __slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__'

        def __init__(self, **kwargs):
            # Record the exact kwargs received so the test can check what __init__ was given.
            init_calls.append(repr(kwargs))
            self.__dict__.update(kwargs)
            self.__pydantic_fields_set__ = {'custom'}
            self.__pydantic_extra__ = None

    fields_schema = core_schema.model_fields_schema({'a': core_schema.model_field(core_schema.int_schema())})
    validator = SchemaValidator(
        core_schema.model_schema(
            Model,
            fields_schema,
            custom_init=True,
            config=dict(revalidate_instances='always'),
        )
    )

    # First pass: __init__ receives the raw input, so `a` stays the string '1'.
    first = validator.validate_python({'a': '1'})
    assert isinstance(first, Model)
    assert first.a == '1'
    assert first.__pydantic_fields_set__ == {'custom'}
    assert init_calls == ["{'a': '1'}"]
    first.x = 4

    # Revalidation: a new instance is built by calling __init__ with the old instance's dict.
    second = validator.validate_python(first)
    assert second is not first
    assert isinstance(second, Model)
    assert second.a == '1'
    assert second.__dict__ == {'a': '1', 'x': 4}
    assert second.__pydantic_fields_set__ == {'custom'}
    assert init_calls == ["{'a': '1'}", "{'a': '1', 'x': 4}"]
Loading

0 comments on commit b863be7

Please sign in to comment.