diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs index fe5b1bcdd9f..d3302c0e90a 100644 --- a/src/recursion_guard.rs +++ b/src/recursion_guard.rs @@ -12,8 +12,59 @@ type RecursionKey = ( /// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault. /// It's used in `validators/definition` to detect when a reference is reused within itself. +pub(crate) struct RecursionGuard<'a, S: ContainsRecursionState> { + state: &'a mut S, + obj_id: usize, + node_id: usize, +} + +pub(crate) enum RecursionError { + /// Cyclic reference detected + Cyclic, + /// Recursion limit exceeded + Depth, +} + +impl<S: ContainsRecursionState> RecursionGuard<'_, S> { + /// Creates a recursion guard for the given object and node id. + /// + /// When dropped, this will release the recursion for the given object and node id. + pub fn new(state: &'_ mut S, obj_id: usize, node_id: usize) -> Result<RecursionGuard<'_, S>, RecursionError> { + state.access_recursion_state(|state| { + if !state.insert(obj_id, node_id) { + return Err(RecursionError::Cyclic); + } + if state.incr_depth() { + return Err(RecursionError::Depth); + } + Ok(()) + })?; + Ok(RecursionGuard { state, obj_id, node_id }) + } + + /// Retrieves the underlying state for further use. + pub fn state(&mut self) -> &mut S { + self.state + } +} + +impl<S: ContainsRecursionState> Drop for RecursionGuard<'_, S> { + fn drop(&mut self) { + self.state.access_recursion_state(|state| { + state.decr_depth(); + state.remove(self.obj_id, self.node_id); + }); + } +} + +/// This trait is used to retrieve the recursion state from some other type +pub(crate) trait ContainsRecursionState { + fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R; +} + +/// State for the RecursionGuard. Can also be used directly to increase / decrease depth. #[derive(Debug, Clone, Default)] -pub struct RecursionGuard { +pub struct RecursionState { ids: RecursionStack, // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just // use one number for all validators @@ -31,11 +82,11 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi 255 }; -impl RecursionGuard { +impl RecursionState { // insert a new value // * return `false` if the stack already had it in it // * return `true` if the stack didn't have it in it and it was inserted - pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool { + fn insert(&mut self, obj_id: usize, node_id: usize) -> bool { self.ids.insert((obj_id, node_id)) } @@ -68,7 +119,7 @@ impl RecursionGuard { self.depth = self.depth.saturating_sub(1); } - pub fn remove(&mut self, obj_id: usize, node_id: usize) { + fn remove(&mut self, obj_id: usize, node_id: usize) { self.ids.remove(&(obj_id, node_id)); } } @@ -98,7 +149,7 @@ impl RecursionStack { // insert a new value // * return `false` if the stack already had it in it // * return `true` if the stack didn't have it in it and it was inserted - pub fn insert(&mut self, v: RecursionKey) -> bool { + fn insert(&mut self, v: RecursionKey) -> bool { match self { Self::Array { data, len } => { if *len < ARRAY_SIZE { @@ -129,7 +180,7 @@ impl RecursionStack { } } - pub fn remove(&mut self, v: &RecursionKey) { + fn remove(&mut self, v: &RecursionKey) { match self { Self::Array { data, len } => { *len = len.checked_sub(1).expect("remove from empty recursion guard"); diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index b3978613a93..8d598d46b83 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -10,20 +10,23 @@ use serde::ser::Error; use super::config::SerializationConfig; use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER}; use super::ob_type::ObTypeLookup; +use crate::recursion_guard::ContainsRecursionState; +use crate::recursion_guard::RecursionError; use crate::recursion_guard::RecursionGuard; +use crate::recursion_guard::RecursionState; /// this is ugly, would be much better if extra could be stored in `SerializationState` /// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work pub(crate) struct SerializationState { warnings: CollectWarnings, - rec_guard: SerRecursionGuard, + rec_guard: SerRecursionState, config: SerializationConfig, } impl SerializationState { pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> { let warnings = CollectWarnings::new(false); - let rec_guard = SerRecursionGuard::default(); + let rec_guard = SerRecursionState::default(); let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?; Ok(Self { warnings, @@ -77,7 +80,7 @@ pub(crate) struct Extra<'a> { pub exclude_none: bool, pub round_trip: bool, pub config: &'a SerializationConfig, - pub rec_guard: &'a SerRecursionGuard, + pub rec_guard: &'a SerRecursionState, // the next two are used for union logic pub check: SerCheck, // data representing the current model field @@ -101,7 +104,7 @@ impl<'a> Extra<'a> { exclude_none: bool, round_trip: bool, config: &'a SerializationConfig, - rec_guard: &'a SerRecursionGuard, + rec_guard: &'a SerRecursionState, serialize_unknown: bool, fallback: Option<&'a PyAny>, ) -> Self { @@ -124,6 +127,22 @@ impl<'a> Extra<'a> { } } + pub fn recursion_guard<'x, 'y>( + // TODO: this double reference is a bit if a hack, but it's necessary because the recursion + // guard is not passed around with &mut reference + // + // See how validation has &mut ValidationState passed around; we should aim to refactor + // to match that. + self: &'x mut &'y Self, + value: &PyAny, + def_ref_id: usize, + ) -> PyResult<RecursionGuard<'x, &'y Self>> { + RecursionGuard::new(self, value.as_ptr() as usize, def_ref_id).map_err(|e| match e { + RecursionError::Depth => PyValueError::new_err("Circular reference detected (depth exceeded)"), + RecursionError::Cyclic => PyValueError::new_err("Circular reference detected (id repeated)"), + }) + } + pub fn serialize_infer<'py>(&'py self, value: &'py PyAny) -> super::infer::SerializeInfer<'py> { super::infer::SerializeInfer::new(value, None, None, self) } @@ -157,7 +176,7 @@ pub(crate) struct ExtraOwned { exclude_none: bool, round_trip: bool, config: SerializationConfig, - rec_guard: SerRecursionGuard, + rec_guard: SerRecursionState, check: SerCheck, model: Option<PyObject>, field_name: Option<String>, @@ -340,29 +359,12 @@ impl CollectWarnings { #[derive(Default, Clone)] #[cfg_attr(debug_assertions, derive(Debug))] -pub struct SerRecursionGuard { - guard: RefCell<RecursionGuard>, +pub struct SerRecursionState { + guard: RefCell<RecursionState>, } -impl SerRecursionGuard { - pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> { - let id = value.as_ptr() as usize; - let mut guard = self.guard.borrow_mut(); - - if guard.insert(id, def_ref_id) { - if guard.incr_depth() { - Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) - } else { - Ok(id) - } - } else { - Err(PyValueError::new_err("Circular reference detected (id repeated)")) - } - } - - pub fn pop(&self, id: usize, def_ref_id: usize) { - let mut guard = self.guard.borrow_mut(); - guard.decr_depth(); - guard.remove(id, def_ref_id); +impl ContainsRecursionState for &'_ Extra<'_> { + fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R { + f(&mut self.rec_guard.guard.borrow_mut()) } } diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 5ddf77597d3..487d0d091f6 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -40,19 +40,22 @@ pub(crate) fn infer_to_python_known( value: &PyAny, include: Option<&PyAny>, exclude: Option<&PyAny>, - extra: &Extra, + mut extra: &Extra, ) -> PyResult<PyObject> { let py = value.py(); - let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) { - Ok(id) => id, + + let mode = extra.mode; + let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) { + Ok(v) => v, Err(e) => { - return match extra.mode { + return match mode { SerMode::Json => Err(e), // if recursion is detected by we're serializing to python, we just return the value _ => Ok(value.into_py(py)), }; } }; + let extra = guard.state(); macro_rules! serialize_seq { ($t:ty) => { @@ -220,7 +223,6 @@ pub(crate) fn infer_to_python_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; let next_result = infer_to_python(next_value, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } else if extra.serialize_unknown { serialize_unknown(value).into_py(py) @@ -267,7 +269,6 @@ pub(crate) fn infer_to_python_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; let next_result = infer_to_python(next_value, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } value.into_py(py) @@ -275,7 +276,6 @@ pub(crate) fn infer_to_python_known( _ => value.into_py(py), }, }; - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); Ok(value) } @@ -332,18 +332,21 @@ pub(crate) fn infer_serialize_known<S: Serializer>( serializer: S, include: Option<&PyAny>, exclude: Option<&PyAny>, - extra: &Extra, + mut extra: &Extra, ) -> Result<S::Ok, S::Error> { - let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) { + let extra_serialize_unknown = extra.serialize_unknown; + let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) { Ok(v) => v, Err(e) => { - return if extra.serialize_unknown { + return if extra_serialize_unknown { serializer.serialize_str("...") } else { - Err(e) - } + Err(py_err_se_err(e)) + }; } }; + let extra = guard.state(); + macro_rules! serialize { ($t:ty) => { match value.extract::<$t>() { @@ -506,7 +509,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,)).map_err(py_err_se_err)?; let next_result = infer_serialize(next_value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } else if extra.serialize_unknown { serializer.serialize_str(&serialize_unknown(value)) @@ -520,7 +522,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>( } } }; - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); ser_result } diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 8159691cb5a..7d9c5347c49 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -10,7 +10,7 @@ use crate::py_gc::PyGcTraverse; use config::SerializationConfig; pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue}; -use extra::{CollectWarnings, SerRecursionGuard}; +use extra::{CollectWarnings, SerRecursionState}; pub(crate) use extra::{Extra, SerMode, SerializationState}; pub use shared::CombinedSerializer; use shared::{to_json_bytes, BuildSerializer, TypeSerializer}; @@ -52,7 +52,7 @@ impl SchemaSerializer { exclude_defaults: bool, exclude_none: bool, round_trip: bool, - rec_guard: &'a SerRecursionGuard, + rec_guard: &'a SerRecursionState, serialize_unknown: bool, fallback: Option<&'a PyAny>, ) -> Extra<'b> { @@ -113,7 +113,7 @@ impl SchemaSerializer { ) -> PyResult<PyObject> { let mode: SerMode = mode.into(); let warnings = CollectWarnings::new(warnings); - let rec_guard = SerRecursionGuard::default(); + let rec_guard = SerRecursionState::default(); let extra = self.build_extra( py, &mode, @@ -152,7 +152,7 @@ impl SchemaSerializer { fallback: Option<&PyAny>, ) -> PyResult<PyObject> { let warnings = CollectWarnings::new(warnings); - let rec_guard = SerRecursionGuard::default(); + let rec_guard = SerRecursionState::default(); let extra = self.build_extra( py, &SerMode::Json, diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 99dae5bcd41..2f98a94e072 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -66,14 +66,12 @@ impl TypeSerializer for DefinitionRefSerializer { value: &PyAny, include: Option<&PyAny>, exclude: Option<&PyAny>, - extra: &Extra, + mut extra: &Extra, ) -> PyResult<PyObject> { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let value_id = extra.rec_guard.add(value, self.definition.id())?; - let r = comb_serializer.to_python(value, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id()); - r + let mut guard = extra.recursion_guard(value, self.definition.id())?; + comb_serializer.to_python(value, include, exclude, guard.state()) }) } @@ -87,17 +85,14 @@ impl TypeSerializer for DefinitionRefSerializer { serializer: S, include: Option<&PyAny>, exclude: Option<&PyAny>, - extra: &Extra, + mut extra: &Extra, ) -> Result<S::Ok, S::Error> { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let value_id = extra - .rec_guard - .add(value, self.definition.id()) + let mut guard = extra + .recursion_guard(value, self.definition.id()) .map_err(py_err_se_err)?; - let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id()); - r + comb_serializer.serde_serialize(value, serializer, include, exclude, guard.state()) }) } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index e8c67a690d1..e4bc270c226 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -6,6 +6,7 @@ use crate::definitions::DefinitionRef; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; @@ -76,18 +77,10 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = input.identity() { - if state.recursion_guard.insert(id, self.definition.id()) { - if state.recursion_guard.incr_depth() { - return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); - } - let output = validator.validate(py, input, state); - state.recursion_guard.remove(id, self.definition.id()); - state.recursion_guard.decr_depth(); - output - } else { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) - } + let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); + }; + validator.validate(py, input, guard.state()) } else { validator.validate(py, input, state) } @@ -105,18 +98,10 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = obj.identity() { - if state.recursion_guard.insert(id, self.definition.id()) { - if state.recursion_guard.incr_depth() { - return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); - } - let output = validator.validate_assignment(py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.definition.id()); - state.recursion_guard.decr_depth(); - output - } else { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) - } + let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); + }; + validator.validate_assignment(py, obj, field_name, field_value, guard.state()) } else { validator.validate_assignment(py, obj, field_name, field_value, state) } diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 3b5fedd9723..a366da4312a 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -6,7 +6,7 @@ use pyo3::types::PyDict; use crate::errors::{ErrorType, LocItem, ValError, ValResult}; use crate::input::{GenericIterator, Input}; -use crate::recursion_guard::RecursionGuard; +use crate::recursion_guard::RecursionState; use crate::tools::SchemaDict; use crate::ValidationError; @@ -212,7 +212,7 @@ pub struct InternalValidator { from_attributes: Option<bool>, context: Option<PyObject>, self_instance: Option<PyObject>, - recursion_guard: RecursionGuard, + recursion_guard: RecursionState, pub(crate) exactness: Option<Exactness>, validation_mode: InputType, hide_input_in_errors: bool, diff --git a/src/validators/mod.rs b/src/validators/mod.rs index fc158408077..496da5c739d 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -13,7 +13,7 @@ use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::errors::{LocItem, ValError, ValResult, ValidationError}; use crate::input::{Input, InputType, StringMapping}; use crate::py_gc::PyGcTraverse; -use crate::recursion_guard::RecursionGuard; +use crate::recursion_guard::RecursionState; use crate::tools::SchemaDict; mod any; @@ -263,7 +263,7 @@ impl SchemaValidator { self_instance: None, }; - let guard = &mut RecursionGuard::default(); + let guard = &mut RecursionState::default(); let mut state = ValidationState::new(extra, guard); self.validator .validate_assignment(py, obj, field_name, field_value, &mut state) @@ -280,7 +280,7 @@ impl SchemaValidator { context, self_instance: None, }; - let recursion_guard = &mut RecursionGuard::default(); + let recursion_guard = &mut RecursionState::default(); let mut state = ValidationState::new(extra, recursion_guard); let r = self.validator.default_value(py, None::<i64>, &mut state); match r { @@ -326,7 +326,7 @@ impl SchemaValidator { where 's: 'data, { - let mut recursion_guard = RecursionGuard::default(); + let mut recursion_guard = RecursionState::default(); let mut state = ValidationState::new( Extra::new(strict, from_attributes, context, self_instance, input_type), &mut recursion_guard, @@ -378,7 +378,7 @@ impl<'py> SelfValidator<'py> { } pub fn validate_schema(&self, py: Python<'py>, schema: &'py PyAny, strict: Option<bool>) -> PyResult<&'py PyAny> { - let mut recursion_guard = RecursionGuard::default(); + let mut recursion_guard = RecursionState::default(); let mut state = ValidationState::new( Extra::new(strict, None, None, None, InputType::Python), &mut recursion_guard, diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index aacd7d2af9a..4f241d7680c 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -1,4 +1,4 @@ -use crate::recursion_guard::RecursionGuard; +use crate::recursion_guard::{ContainsRecursionState, RecursionState}; use super::Extra; @@ -10,14 +10,14 @@ pub enum Exactness { } pub struct ValidationState<'a> { - pub recursion_guard: &'a mut RecursionGuard, + pub recursion_guard: &'a mut RecursionState, pub exactness: Option<Exactness>, // deliberately make Extra readonly extra: Extra<'a>, } impl<'a> ValidationState<'a> { - pub fn new(extra: Extra<'a>, recursion_guard: &'a mut RecursionGuard) -> Self { + pub fn new(extra: Extra<'a>, recursion_guard: &'a mut RecursionState) -> Self { Self { recursion_guard, // Don't care about exactness unless doing union validation exactness: None, @@ -84,6 +84,12 @@ impl<'a> ValidationState<'a> { } } +impl ContainsRecursionState for ValidationState<'_> { + fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R { + f(self.recursion_guard) + } +} + pub struct ValidationStateWithReboundExtra<'state, 'a> { state: &'state mut ValidationState<'a>, old_extra: Extra<'a>,