Skip to content

Commit

Permalink
Support default factories taking validated data as an argument (#1491)
Browse files Browse the repository at this point in the history
  • Loading branch information
Viicos authored Oct 23, 2024
1 parent 288dd1c commit ff08c20
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 12 deletions.
10 changes: 7 additions & 3 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2375,7 +2375,8 @@ class WithDefaultSchema(TypedDict, total=False):
type: Required[Literal['default']]
schema: Required[CoreSchema]
default: Any
default_factory: Callable[[], Any]
default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any]]
default_factory_takes_data: bool
on_error: Literal['raise', 'omit', 'default'] # default: 'raise'
validate_default: bool # default: False
strict: bool
Expand All @@ -2388,7 +2389,8 @@ def with_default_schema(
schema: CoreSchema,
*,
default: Any = PydanticUndefined,
default_factory: Callable[[], Any] | None = None,
default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any], None] = None,
default_factory_takes_data: bool | None = None,
on_error: Literal['raise', 'omit', 'default'] | None = None,
validate_default: bool | None = None,
strict: bool | None = None,
Expand All @@ -2413,7 +2415,8 @@ def with_default_schema(
Args:
schema: The schema to add a default value to
default: The default value to use
default_factory: A function that returns the default value to use
default_factory: A callable that returns the default value to use
default_factory_takes_data: Whether the default factory takes a validated data argument
on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default'
validate_default: Whether the default value should be validated
strict: Whether the underlying schema should be validated with strict mode
Expand All @@ -2425,6 +2428,7 @@ def with_default_schema(
type='default',
schema=schema,
default_factory=default_factory,
default_factory_takes_data=default_factory_takes_data,
on_error=on_error,
validate_default=validate_default,
strict=strict,
Expand Down
10 changes: 9 additions & 1 deletion src/serializers/type_serializers/with_default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ impl TypeSerializer for WithDefaultSerializer {
}

fn get_default(&self, py: Python) -> PyResult<Option<PyObject>> {
self.default.default_value(py)
if let DefaultType::DefaultFactory(_, _takes_data @ true) = self.default {
// We currently don't compute the default if the default factory takes
// the data from other fields.
Ok(None)
} else {
self.default.default_value(
py, &None, // Won't be used.
)
}
}
}
29 changes: 23 additions & 6 deletions src/validators/with_default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn get_deepcopy(py: Python) -> PyResult<PyObject> {
pub enum DefaultType {
None,
Default(PyObject),
DefaultFactory(PyObject),
DefaultFactory(PyObject, bool),
}

impl DefaultType {
Expand All @@ -37,23 +37,40 @@ impl DefaultType {
) {
(Some(_), Some(_)) => py_schema_err!("'default' and 'default_factory' cannot be used together"),
(Some(default), None) => Ok(Self::Default(default)),
(None, Some(default_factory)) => Ok(Self::DefaultFactory(default_factory)),
(None, Some(default_factory)) => Ok(Self::DefaultFactory(
default_factory,
schema
.get_as::<bool>(intern!(py, "default_factory_takes_data"))?
.unwrap_or(false),
)),
(None, None) => Ok(Self::None),
}
}

pub fn default_value(&self, py: Python) -> PyResult<Option<PyObject>> {
pub fn default_value(&self, py: Python, validated_data: &Option<Bound<PyDict>>) -> PyResult<Option<PyObject>> {
match self {
Self::Default(ref default) => Ok(Some(default.clone_ref(py))),
Self::DefaultFactory(ref default_factory) => Ok(Some(default_factory.call0(py)?)),
Self::DefaultFactory(ref default_factory, ref takes_data) => {
let result = if *takes_data {
if validated_data.is_none() {
default_factory.call1(py, ({},))
} else {
default_factory.call1(py, (validated_data.as_deref().unwrap(),))
}
} else {
default_factory.call0(py)
};

Ok(Some(result?))
}
Self::None => Ok(None),
}
}
}

impl PyGcTraverse for DefaultType {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
if let Self::Default(obj) | Self::DefaultFactory(obj) = self {
if let Self::Default(obj) | Self::DefaultFactory(obj, _) = self {
visit.call(obj)?;
}
Ok(())
Expand Down Expand Up @@ -163,7 +180,7 @@ impl Validator for WithDefaultValidator {
outer_loc: Option<impl Into<LocItem>>,
state: &mut ValidationState<'_, 'py>,
) -> ValResult<Option<PyObject>> {
match self.default.default_value(py)? {
match self.default.default_value(py, &state.extra().data)? {
Some(stored_dft) => {
let dft: Py<PyAny> = if self.copy_default {
let deepcopy_func = COPY_DEEPCOPY.get_or_init(py, || get_deepcopy(py).unwrap());
Expand Down
10 changes: 8 additions & 2 deletions tests/validators/test_with_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,18 @@ def broken():
v.validate_python('wrong')


def test_factory_type_error():
def test_factory_missing_arg():
def broken(x):
return 7

v = SchemaValidator(
{'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'default', 'default_factory': broken}
{
'type': 'default',
'schema': {'type': 'int'},
'on_error': 'default',
'default_factory': broken,
'default_factory_takes_data': False,
}
)
assert v.validate_python(42) == 42
assert v.validate_python('42') == 42
Expand Down

0 comments on commit ff08c20

Please sign in to comment.