diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs
index fe5b1bcdd9f..d3302c0e90a 100644
--- a/src/recursion_guard.rs
+++ b/src/recursion_guard.rs
@@ -12,8 +12,59 @@ type RecursionKey = (
 
 /// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
 /// It's used in `validators/definition` to detect when a reference is reused within itself.
+pub(crate) struct RecursionGuard<'a, S: ContainsRecursionState> {
+    state: &'a mut S,
+    obj_id: usize,
+    node_id: usize,
+}
+
+pub(crate) enum RecursionError {
+    /// Cyclic reference detected
+    Cyclic,
+    /// Recursion limit exceeded
+    Depth,
+}
+
+impl<S: ContainsRecursionState> RecursionGuard<'_, S> {
+    /// Creates a recursion guard for the given object and node id.
+    ///
+    /// When dropped, this will release the recursion for the given object and node id.
+    pub fn new(state: &'_ mut S, obj_id: usize, node_id: usize) -> Result<RecursionGuard<'_, S>, RecursionError> {
+        state.access_recursion_state(|state| {
+            if !state.insert(obj_id, node_id) {
+                return Err(RecursionError::Cyclic);
+            }
+            if state.incr_depth() {
+                return Err(RecursionError::Depth);
+            }
+            Ok(())
+        })?;
+        Ok(RecursionGuard { state, obj_id, node_id })
+    }
+
+    /// Retrieves the underlying state for further use.
+    pub fn state(&mut self) -> &mut S {
+        self.state
+    }
+}
+
+impl<S: ContainsRecursionState> Drop for RecursionGuard<'_, S> {
+    fn drop(&mut self) {
+        self.state.access_recursion_state(|state| {
+            state.decr_depth();
+            state.remove(self.obj_id, self.node_id);
+        });
+    }
+}
+
+/// This trait is used to retrieve the recursion state from some other type
+pub(crate) trait ContainsRecursionState {
+    fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R;
+}
+
+/// State for the RecursionGuard. Can also be used directly to increase / decrease depth.
 #[derive(Debug, Clone, Default)]
-pub struct RecursionGuard {
+pub struct RecursionState {
     ids: RecursionStack,
     // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
     // use one number for all validators
@@ -31,11 +82,11 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi
     255
 };
 
-impl RecursionGuard {
+impl RecursionState {
     // insert a new value
     // * return `false` if the stack already had it in it
     // * return `true` if the stack didn't have it in it and it was inserted
-    pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
+    fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
         self.ids.insert((obj_id, node_id))
     }
 
@@ -68,7 +119,7 @@ impl RecursionGuard {
         self.depth = self.depth.saturating_sub(1);
     }
 
-    pub fn remove(&mut self, obj_id: usize, node_id: usize) {
+    fn remove(&mut self, obj_id: usize, node_id: usize) {
         self.ids.remove(&(obj_id, node_id));
     }
 }
@@ -98,7 +149,7 @@ impl RecursionStack {
     // insert a new value
     // * return `false` if the stack already had it in it
     // * return `true` if the stack didn't have it in it and it was inserted
-    pub fn insert(&mut self, v: RecursionKey) -> bool {
+    fn insert(&mut self, v: RecursionKey) -> bool {
         match self {
             Self::Array { data, len } => {
                 if *len < ARRAY_SIZE {
@@ -129,7 +180,7 @@ impl RecursionStack {
         }
     }
 
-    pub fn remove(&mut self, v: &RecursionKey) {
+    fn remove(&mut self, v: &RecursionKey) {
         match self {
             Self::Array { data, len } => {
                 *len = len.checked_sub(1).expect("remove from empty recursion guard");
diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs
index b3978613a93..8d598d46b83 100644
--- a/src/serializers/extra.rs
+++ b/src/serializers/extra.rs
@@ -10,20 +10,23 @@ use serde::ser::Error;
 use super::config::SerializationConfig;
 use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
 use super::ob_type::ObTypeLookup;
+use crate::recursion_guard::ContainsRecursionState;
+use crate::recursion_guard::RecursionError;
 use crate::recursion_guard::RecursionGuard;
+use crate::recursion_guard::RecursionState;
 
 /// this is ugly, would be much better if extra could be stored in `SerializationState`
 /// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work
 pub(crate) struct SerializationState {
     warnings: CollectWarnings,
-    rec_guard: SerRecursionGuard,
+    rec_guard: SerRecursionState,
     config: SerializationConfig,
 }
 
 impl SerializationState {
     pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
         let warnings = CollectWarnings::new(false);
-        let rec_guard = SerRecursionGuard::default();
+        let rec_guard = SerRecursionState::default();
         let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?;
         Ok(Self {
             warnings,
@@ -77,7 +80,7 @@ pub(crate) struct Extra<'a> {
     pub exclude_none: bool,
     pub round_trip: bool,
     pub config: &'a SerializationConfig,
-    pub rec_guard: &'a SerRecursionGuard,
+    pub rec_guard: &'a SerRecursionState,
     // the next two are used for union logic
     pub check: SerCheck,
     // data representing the current model field
@@ -101,7 +104,7 @@ impl<'a> Extra<'a> {
         exclude_none: bool,
         round_trip: bool,
         config: &'a SerializationConfig,
-        rec_guard: &'a SerRecursionGuard,
+        rec_guard: &'a SerRecursionState,
         serialize_unknown: bool,
         fallback: Option<&'a PyAny>,
     ) -> Self {
@@ -124,6 +127,22 @@ impl<'a> Extra<'a> {
         }
     }
 
+    pub fn recursion_guard<'x, 'y>(
+        // TODO: this double reference is a bit if a hack, but it's necessary because the recursion
+        // guard is not passed around with &mut reference
+        //
+        // See how validation has &mut ValidationState passed around; we should aim to refactor
+        // to match that.
+        self: &'x mut &'y Self,
+        value: &PyAny,
+        def_ref_id: usize,
+    ) -> PyResult<RecursionGuard<'x, &'y Self>> {
+        RecursionGuard::new(self, value.as_ptr() as usize, def_ref_id).map_err(|e| match e {
+            RecursionError::Depth => PyValueError::new_err("Circular reference detected (depth exceeded)"),
+            RecursionError::Cyclic => PyValueError::new_err("Circular reference detected (id repeated)"),
+        })
+    }
+
     pub fn serialize_infer<'py>(&'py self, value: &'py PyAny) -> super::infer::SerializeInfer<'py> {
         super::infer::SerializeInfer::new(value, None, None, self)
     }
@@ -157,7 +176,7 @@ pub(crate) struct ExtraOwned {
     exclude_none: bool,
     round_trip: bool,
     config: SerializationConfig,
-    rec_guard: SerRecursionGuard,
+    rec_guard: SerRecursionState,
     check: SerCheck,
     model: Option<PyObject>,
     field_name: Option<String>,
@@ -340,29 +359,12 @@ impl CollectWarnings {
 
 #[derive(Default, Clone)]
 #[cfg_attr(debug_assertions, derive(Debug))]
-pub struct SerRecursionGuard {
-    guard: RefCell<RecursionGuard>,
+pub struct SerRecursionState {
+    guard: RefCell<RecursionState>,
 }
 
-impl SerRecursionGuard {
-    pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
-        let id = value.as_ptr() as usize;
-        let mut guard = self.guard.borrow_mut();
-
-        if guard.insert(id, def_ref_id) {
-            if guard.incr_depth() {
-                Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
-            } else {
-                Ok(id)
-            }
-        } else {
-            Err(PyValueError::new_err("Circular reference detected (id repeated)"))
-        }
-    }
-
-    pub fn pop(&self, id: usize, def_ref_id: usize) {
-        let mut guard = self.guard.borrow_mut();
-        guard.decr_depth();
-        guard.remove(id, def_ref_id);
+impl ContainsRecursionState for &'_ Extra<'_> {
+    fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R {
+        f(&mut self.rec_guard.guard.borrow_mut())
     }
 }
diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs
index 5ddf77597d3..487d0d091f6 100644
--- a/src/serializers/infer.rs
+++ b/src/serializers/infer.rs
@@ -40,19 +40,22 @@ pub(crate) fn infer_to_python_known(
     value: &PyAny,
     include: Option<&PyAny>,
     exclude: Option<&PyAny>,
-    extra: &Extra,
+    mut extra: &Extra,
 ) -> PyResult<PyObject> {
     let py = value.py();
-    let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
-        Ok(id) => id,
+
+    let mode = extra.mode;
+    let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) {
+        Ok(v) => v,
         Err(e) => {
-            return match extra.mode {
+            return match mode {
                 SerMode::Json => Err(e),
                 // if recursion is detected by we're serializing to python, we just return the value
                 _ => Ok(value.into_py(py)),
             };
         }
     };
+    let extra = guard.state();
 
     macro_rules! serialize_seq {
         ($t:ty) => {
@@ -220,7 +223,6 @@ pub(crate) fn infer_to_python_known(
                 if let Some(fallback) = extra.fallback {
                     let next_value = fallback.call1((value,))?;
                     let next_result = infer_to_python(next_value, include, exclude, extra);
-                    extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
                     return next_result;
                 } else if extra.serialize_unknown {
                     serialize_unknown(value).into_py(py)
@@ -267,7 +269,6 @@ pub(crate) fn infer_to_python_known(
                 if let Some(fallback) = extra.fallback {
                     let next_value = fallback.call1((value,))?;
                     let next_result = infer_to_python(next_value, include, exclude, extra);
-                    extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
                     return next_result;
                 }
                 value.into_py(py)
@@ -275,7 +276,6 @@ pub(crate) fn infer_to_python_known(
             _ => value.into_py(py),
         },
     };
-    extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
     Ok(value)
 }
 
@@ -332,18 +332,21 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
     serializer: S,
     include: Option<&PyAny>,
     exclude: Option<&PyAny>,
-    extra: &Extra,
+    mut extra: &Extra,
 ) -> Result<S::Ok, S::Error> {
-    let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
+    let extra_serialize_unknown = extra.serialize_unknown;
+    let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) {
         Ok(v) => v,
         Err(e) => {
-            return if extra.serialize_unknown {
+            return if extra_serialize_unknown {
                 serializer.serialize_str("...")
             } else {
-                Err(e)
-            }
+                Err(py_err_se_err(e))
+            };
         }
     };
+    let extra = guard.state();
+
     macro_rules! serialize {
         ($t:ty) => {
             match value.extract::<$t>() {
@@ -506,7 +509,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
             if let Some(fallback) = extra.fallback {
                 let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
                 let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
-                extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
                 return next_result;
             } else if extra.serialize_unknown {
                 serializer.serialize_str(&serialize_unknown(value))
@@ -520,7 +522,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
             }
         }
     };
-    extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
     ser_result
 }
 
diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs
index 8159691cb5a..7d9c5347c49 100644
--- a/src/serializers/mod.rs
+++ b/src/serializers/mod.rs
@@ -10,7 +10,7 @@ use crate::py_gc::PyGcTraverse;
 
 use config::SerializationConfig;
 pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
-use extra::{CollectWarnings, SerRecursionGuard};
+use extra::{CollectWarnings, SerRecursionState};
 pub(crate) use extra::{Extra, SerMode, SerializationState};
 pub use shared::CombinedSerializer;
 use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
@@ -52,7 +52,7 @@ impl SchemaSerializer {
         exclude_defaults: bool,
         exclude_none: bool,
         round_trip: bool,
-        rec_guard: &'a SerRecursionGuard,
+        rec_guard: &'a SerRecursionState,
         serialize_unknown: bool,
         fallback: Option<&'a PyAny>,
     ) -> Extra<'b> {
@@ -113,7 +113,7 @@ impl SchemaSerializer {
     ) -> PyResult<PyObject> {
         let mode: SerMode = mode.into();
         let warnings = CollectWarnings::new(warnings);
-        let rec_guard = SerRecursionGuard::default();
+        let rec_guard = SerRecursionState::default();
         let extra = self.build_extra(
             py,
             &mode,
@@ -152,7 +152,7 @@ impl SchemaSerializer {
         fallback: Option<&PyAny>,
     ) -> PyResult<PyObject> {
         let warnings = CollectWarnings::new(warnings);
-        let rec_guard = SerRecursionGuard::default();
+        let rec_guard = SerRecursionState::default();
         let extra = self.build_extra(
             py,
             &SerMode::Json,
diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs
index 99dae5bcd41..2f98a94e072 100644
--- a/src/serializers/type_serializers/definitions.rs
+++ b/src/serializers/type_serializers/definitions.rs
@@ -66,14 +66,12 @@ impl TypeSerializer for DefinitionRefSerializer {
         value: &PyAny,
         include: Option<&PyAny>,
         exclude: Option<&PyAny>,
-        extra: &Extra,
+        mut extra: &Extra,
     ) -> PyResult<PyObject> {
         self.definition.read(|comb_serializer| {
             let comb_serializer = comb_serializer.unwrap();
-            let value_id = extra.rec_guard.add(value, self.definition.id())?;
-            let r = comb_serializer.to_python(value, include, exclude, extra);
-            extra.rec_guard.pop(value_id, self.definition.id());
-            r
+            let mut guard = extra.recursion_guard(value, self.definition.id())?;
+            comb_serializer.to_python(value, include, exclude, guard.state())
         })
     }
 
@@ -87,17 +85,14 @@ impl TypeSerializer for DefinitionRefSerializer {
         serializer: S,
         include: Option<&PyAny>,
         exclude: Option<&PyAny>,
-        extra: &Extra,
+        mut extra: &Extra,
     ) -> Result<S::Ok, S::Error> {
         self.definition.read(|comb_serializer| {
             let comb_serializer = comb_serializer.unwrap();
-            let value_id = extra
-                .rec_guard
-                .add(value, self.definition.id())
+            let mut guard = extra
+                .recursion_guard(value, self.definition.id())
                 .map_err(py_err_se_err)?;
-            let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
-            extra.rec_guard.pop(value_id, self.definition.id());
-            r
+            comb_serializer.serde_serialize(value, serializer, include, exclude, guard.state())
         })
     }
 
diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs
index e8c67a690d1..e4bc270c226 100644
--- a/src/validators/definitions.rs
+++ b/src/validators/definitions.rs
@@ -6,6 +6,7 @@ use crate::definitions::DefinitionRef;
 use crate::errors::{ErrorTypeDefaults, ValError, ValResult};
 use crate::input::Input;
 
+use crate::recursion_guard::RecursionGuard;
 use crate::tools::SchemaDict;
 
 use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
@@ -76,18 +77,10 @@ impl Validator for DefinitionRefValidator {
         self.definition.read(|validator| {
             let validator = validator.unwrap();
             if let Some(id) = input.identity() {
-                if state.recursion_guard.insert(id, self.definition.id()) {
-                    if state.recursion_guard.incr_depth() {
-                        return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
-                    }
-                    let output = validator.validate(py, input, state);
-                    state.recursion_guard.remove(id, self.definition.id());
-                    state.recursion_guard.decr_depth();
-                    output
-                } else {
-                    // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
-                    Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
-                }
+                let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else {
+                    return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
+                };
+                validator.validate(py, input, guard.state())
             } else {
                 validator.validate(py, input, state)
             }
@@ -105,18 +98,10 @@ impl Validator for DefinitionRefValidator {
         self.definition.read(|validator| {
             let validator = validator.unwrap();
             if let Some(id) = obj.identity() {
-                if state.recursion_guard.insert(id, self.definition.id()) {
-                    if state.recursion_guard.incr_depth() {
-                        return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
-                    }
-                    let output = validator.validate_assignment(py, obj, field_name, field_value, state);
-                    state.recursion_guard.remove(id, self.definition.id());
-                    state.recursion_guard.decr_depth();
-                    output
-                } else {
-                    // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
-                    Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
-                }
+                let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else {
+                    return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
+                };
+                validator.validate_assignment(py, obj, field_name, field_value, guard.state())
             } else {
                 validator.validate_assignment(py, obj, field_name, field_value, state)
             }
diff --git a/src/validators/generator.rs b/src/validators/generator.rs
index 3b5fedd9723..a366da4312a 100644
--- a/src/validators/generator.rs
+++ b/src/validators/generator.rs
@@ -6,7 +6,7 @@ use pyo3::types::PyDict;
 
 use crate::errors::{ErrorType, LocItem, ValError, ValResult};
 use crate::input::{GenericIterator, Input};
-use crate::recursion_guard::RecursionGuard;
+use crate::recursion_guard::RecursionState;
 use crate::tools::SchemaDict;
 use crate::ValidationError;
 
@@ -212,7 +212,7 @@ pub struct InternalValidator {
     from_attributes: Option<bool>,
     context: Option<PyObject>,
     self_instance: Option<PyObject>,
-    recursion_guard: RecursionGuard,
+    recursion_guard: RecursionState,
     pub(crate) exactness: Option<Exactness>,
     validation_mode: InputType,
     hide_input_in_errors: bool,
diff --git a/src/validators/mod.rs b/src/validators/mod.rs
index fc158408077..496da5c739d 100644
--- a/src/validators/mod.rs
+++ b/src/validators/mod.rs
@@ -13,7 +13,7 @@ use crate::definitions::{Definitions, DefinitionsBuilder};
 use crate::errors::{LocItem, ValError, ValResult, ValidationError};
 use crate::input::{Input, InputType, StringMapping};
 use crate::py_gc::PyGcTraverse;
-use crate::recursion_guard::RecursionGuard;
+use crate::recursion_guard::RecursionState;
 use crate::tools::SchemaDict;
 
 mod any;
@@ -263,7 +263,7 @@ impl SchemaValidator {
             self_instance: None,
         };
 
-        let guard = &mut RecursionGuard::default();
+        let guard = &mut RecursionState::default();
         let mut state = ValidationState::new(extra, guard);
         self.validator
             .validate_assignment(py, obj, field_name, field_value, &mut state)
@@ -280,7 +280,7 @@ impl SchemaValidator {
             context,
             self_instance: None,
         };
-        let recursion_guard = &mut RecursionGuard::default();
+        let recursion_guard = &mut RecursionState::default();
         let mut state = ValidationState::new(extra, recursion_guard);
         let r = self.validator.default_value(py, None::<i64>, &mut state);
         match r {
@@ -326,7 +326,7 @@ impl SchemaValidator {
     where
         's: 'data,
     {
-        let mut recursion_guard = RecursionGuard::default();
+        let mut recursion_guard = RecursionState::default();
         let mut state = ValidationState::new(
             Extra::new(strict, from_attributes, context, self_instance, input_type),
             &mut recursion_guard,
@@ -378,7 +378,7 @@ impl<'py> SelfValidator<'py> {
     }
 
     pub fn validate_schema(&self, py: Python<'py>, schema: &'py PyAny, strict: Option<bool>) -> PyResult<&'py PyAny> {
-        let mut recursion_guard = RecursionGuard::default();
+        let mut recursion_guard = RecursionState::default();
         let mut state = ValidationState::new(
             Extra::new(strict, None, None, None, InputType::Python),
             &mut recursion_guard,
diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs
index aacd7d2af9a..4f241d7680c 100644
--- a/src/validators/validation_state.rs
+++ b/src/validators/validation_state.rs
@@ -1,4 +1,4 @@
-use crate::recursion_guard::RecursionGuard;
+use crate::recursion_guard::{ContainsRecursionState, RecursionState};
 
 use super::Extra;
 
@@ -10,14 +10,14 @@ pub enum Exactness {
 }
 
 pub struct ValidationState<'a> {
-    pub recursion_guard: &'a mut RecursionGuard,
+    pub recursion_guard: &'a mut RecursionState,
     pub exactness: Option<Exactness>,
     // deliberately make Extra readonly
     extra: Extra<'a>,
 }
 
 impl<'a> ValidationState<'a> {
-    pub fn new(extra: Extra<'a>, recursion_guard: &'a mut RecursionGuard) -> Self {
+    pub fn new(extra: Extra<'a>, recursion_guard: &'a mut RecursionState) -> Self {
         Self {
             recursion_guard, // Don't care about exactness unless doing union validation
             exactness: None,
@@ -84,6 +84,12 @@ impl<'a> ValidationState<'a> {
     }
 }
 
+impl ContainsRecursionState for ValidationState<'_> {
+    fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R {
+        f(self.recursion_guard)
+    }
+}
+
 pub struct ValidationStateWithReboundExtra<'state, 'a> {
     state: &'state mut ValidationState<'a>,
     old_extra: Extra<'a>,