Skip to content

Commit 99136a5

Browse files
committed
refactor common extraction logic
1 parent 3267736 commit 99136a5

File tree

5 files changed

+65
-77
lines changed

5 files changed

+65
-77
lines changed

src/common/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
pub(crate) mod prebuilt;
12
pub(crate) mod union;

src/common/prebuilt.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use pyo3::intern;
2+
use pyo3::prelude::*;
3+
use pyo3::types::{PyAny, PyDict, PyType};
4+
5+
use crate::tools::SchemaDict;
6+
7+
pub fn get_prebuilt<T>(
8+
type_: &str,
9+
schema: &Bound<'_, PyDict>,
10+
prebuilt_attr_name: &str,
11+
extractor: impl FnOnce(Bound<'_, PyAny>) -> PyResult<T>,
12+
) -> PyResult<Option<T>> {
13+
let py = schema.py();
14+
15+
// we can only use prebuilt validators / serializers from models, typed dicts, and dataclasses
16+
// however, we don't want to use a prebuilt structure from dataclasses if we have a generic_origin
17+
// because the validator / serializer is cached on the unparametrized dataclass
18+
if !matches!(type_, "model" | "typed-dict")
19+
|| matches!(type_, "dataclass") && schema.contains(intern!(py, "generic_origin"))?
20+
{
21+
return Ok(None);
22+
}
23+
24+
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
25+
26+
// Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr)
27+
// because we don't want to fetch prebuilt validators from parent classes.
28+
// We don't downcast here because __dict__ on a class is a readonly mappingproxy,
29+
// so we can just leave it as is and do get_item checks.
30+
let class_dict = class.getattr(intern!(py, "__dict__"))?;
31+
32+
let is_complete: bool = class_dict
33+
.get_item(intern!(py, "__pydantic_complete__"))
34+
.is_ok_and(|b| b.extract().unwrap_or(false));
35+
36+
if !is_complete {
37+
return Ok(None);
38+
}
39+
40+
// Retrieve the prebuilt validator / serializer if available
41+
let prebuilt: Bound<'_, PyAny> = class_dict.get_item(prebuilt_attr_name)?;
42+
extractor(prebuilt).map(Some)
43+
}

src/serializers/prebuilt.rs

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,30 @@
11
use std::borrow::Cow;
22

3-
use pyo3::intern;
43
use pyo3::prelude::*;
5-
use pyo3::types::{PyDict, PyType};
4+
use pyo3::types::PyDict;
65

7-
use crate::tools::SchemaDict;
6+
use crate::common::prebuilt::get_prebuilt;
87
use crate::SchemaSerializer;
98

109
use super::extra::Extra;
1110
use super::shared::{CombinedSerializer, TypeSerializer};
1211

1312
#[derive(Debug)]
1413
pub struct PrebuiltSerializer {
15-
serializer: Py<SchemaSerializer>,
14+
schema_serializer: Py<SchemaSerializer>,
1615
}
1716

1817
impl PrebuiltSerializer {
1918
pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult<Option<CombinedSerializer>> {
20-
let py = schema.py();
21-
22-
// we can only use prebuilt serializeres from models, typed dicts, and dataclasses
23-
// however, we don't want to use a prebuilt serializer for dataclasses if we have a generic_origin
24-
// because __pydantic_serializer__ is cached on the unparametrized dataclass
25-
if !matches!(type_, "model" | "typed-dict")
26-
|| matches!(type_, "dataclass") && schema.contains(intern!(py, "generic_origin"))?
27-
{
28-
return Ok(None);
29-
}
30-
31-
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
32-
33-
// Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr)
34-
// because we don't want to fetch prebuilt validators from parent classes.
35-
// We don't downcast here because __dict__ on a class is a readonly mappingproxy,
36-
// so we can just leave it as is and do get_item checks.
37-
let class_dict = class.getattr(intern!(py, "__dict__"))?;
38-
39-
let is_complete: bool = class_dict
40-
.get_item(intern!(py, "__pydantic_complete__"))
41-
.is_ok_and(|b| b.extract().unwrap_or(false));
42-
43-
if !is_complete {
44-
return Ok(None);
45-
}
46-
47-
// Retrieve the prebuilt validator if available
48-
let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_item(intern!(py, "__pydantic_serializer__"))?;
49-
let serializer: Py<SchemaSerializer> = prebuilt_serializer.extract()?;
50-
51-
Ok(Some(Self { serializer }.into()))
19+
get_prebuilt(type_, schema, "__pydantic_serializer__", |py_any| {
20+
py_any
21+
.extract::<Py<SchemaSerializer>>()
22+
.map(|schema_serializer| Self { schema_serializer }.into())
23+
})
5224
}
5325
}
5426

55-
impl_py_gc_traverse!(PrebuiltSerializer { serializer });
27+
impl_py_gc_traverse!(PrebuiltSerializer { schema_serializer });
5628

5729
impl TypeSerializer for PrebuiltSerializer {
5830
fn to_python(
@@ -62,14 +34,14 @@ impl TypeSerializer for PrebuiltSerializer {
6234
exclude: Option<&Bound<'_, PyAny>>,
6335
extra: &Extra,
6436
) -> PyResult<PyObject> {
65-
self.serializer
37+
self.schema_serializer
6638
.get()
6739
.serializer
6840
.to_python(value, include, exclude, extra)
6941
}
7042

7143
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
72-
self.serializer.get().serializer.json_key(key, extra)
44+
self.schema_serializer.get().serializer.json_key(key, extra)
7345
}
7446

7547
fn serde_serialize<S: serde::ser::Serializer>(
@@ -80,17 +52,17 @@ impl TypeSerializer for PrebuiltSerializer {
8052
exclude: Option<&Bound<'_, PyAny>>,
8153
extra: &Extra,
8254
) -> Result<S::Ok, S::Error> {
83-
self.serializer
55+
self.schema_serializer
8456
.get()
8557
.serializer
8658
.serde_serialize(value, serializer, include, exclude, extra)
8759
}
8860

8961
fn get_name(&self) -> &str {
90-
self.serializer.get().serializer.get_name()
62+
self.schema_serializer.get().serializer.get_name()
9163
}
9264

9365
fn retry_with_lax_check(&self) -> bool {
94-
self.serializer.get().serializer.retry_with_lax_check()
66+
self.schema_serializer.get().serializer.retry_with_lax_check()
9567
}
9668
}

src/validators/prebuilt.rs

Lines changed: 7 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
use pyo3::intern;
21
use pyo3::prelude::*;
3-
use pyo3::types::{PyDict, PyType};
2+
use pyo3::types::PyDict;
43

4+
use crate::common::prebuilt::get_prebuilt;
55
use crate::errors::ValResult;
66
use crate::input::Input;
7-
use crate::tools::SchemaDict;
87

98
use super::ValidationState;
109
use super::{CombinedValidator, SchemaValidator, Validator};
@@ -16,38 +15,11 @@ pub struct PrebuiltValidator {
1615

1716
impl PrebuiltValidator {
1817
pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult<Option<CombinedValidator>> {
19-
let py = schema.py();
20-
21-
// we can only use prebuilt validators from models, typed dicts, and dataclasses
22-
// however, we don't want to use a prebuilt validator for dataclasses if we have a generic_origin
23-
// because __pydantic_validator__ is cached on the unparametrized dataclass
24-
if !matches!(type_, "model" | "typed-dict")
25-
|| matches!(type_, "dataclass") && schema.contains(intern!(py, "generic_origin"))?
26-
{
27-
return Ok(None);
28-
}
29-
30-
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
31-
32-
// Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr)
33-
// because we don't want to fetch prebuilt validators from parent classes.
34-
// We don't downcast here because __dict__ on a class is a readonly mappingproxy,
35-
// so we can just leave it as is and do get_item checks.
36-
let class_dict = class.getattr(intern!(py, "__dict__"))?;
37-
38-
let is_complete: bool = class_dict
39-
.get_item(intern!(py, "__pydantic_complete__"))
40-
.is_ok_and(|b| b.extract().unwrap_or(false));
41-
42-
if !is_complete {
43-
return Ok(None);
44-
}
45-
46-
// Retrieve the prebuilt validator if available
47-
let prebuilt_validator = class_dict.get_item(intern!(py, "__pydantic_validator__"))?;
48-
let schema_validator = prebuilt_validator.extract::<Py<SchemaValidator>>()?;
49-
50-
Ok(Some(Self { schema_validator }.into()))
18+
get_prebuilt(type_, schema, "__pydantic_validator__", |py_any| {
19+
py_any
20+
.extract::<Py<SchemaValidator>>()
21+
.map(|schema_validator| Self { schema_validator }.into())
22+
})
5123
}
5224
}
5325

File renamed without changes.

0 commit comments

Comments
 (0)