diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index d28bac2ec..9722669cc 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -19,6 +19,7 @@ use serde::{ser::Error, Serialize, Serializer}; use crate::errors::{ py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ToErrorValue, ValError, ValLineError, ValResult, }; +use crate::py_gc::PyGcTraverse; use crate::tools::{extract_i64, new_py_string, py_err}; use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator}; @@ -327,6 +328,15 @@ pub enum GenericIterator<'data> { JsonArray(GenericJsonIterator<'data>), } +impl PyGcTraverse for GenericIterator<'_> { + fn py_gc_traverse(&self, visit: &pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> { + if let Self::PyIterator(iter) = self { + iter.py_gc_traverse(visit)?; + } + Ok(()) + } +} + impl GenericIterator<'_> { pub(crate) fn into_static(self) -> GenericIterator<'static> { match self { @@ -385,6 +395,8 @@ impl GenericPyIterator { } } +impl_py_gc_traverse!(GenericPyIterator { obj, iter }); + #[derive(Debug, Clone)] pub struct GenericJsonIterator<'data> { array: JsonArray<'data>, diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 904f4a0ef..de9949a8c 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -1,11 +1,12 @@ use std::fmt; use std::sync::Arc; -use pyo3::prelude::*; use pyo3::types::PyDict; +use pyo3::{prelude::*, PyTraverseError, PyVisit}; use crate::errors::{ErrorType, LocItem, ValError, ValResult}; use crate::input::{BorrowInput, GenericIterator, Input}; +use crate::py_gc::PyGcTraverse; use crate::recursion_guard::RecursionState; use crate::tools::SchemaDict; use crate::ValidationError; @@ -201,6 +202,12 @@ impl ValidatorIterator { fn __str__(&self) -> String { self.__repr__() } + + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.iterator.py_gc_traverse(&visit)?; + self.validator.py_gc_traverse(&visit)?; + Ok(()) + } } /// Owned validator wrapper for use in generators in functions, this can be passed back to python diff --git a/tests/test_garbage_collection.py b/tests/test_garbage_collection.py index 97107e61b..f4e178f26 100644 --- a/tests/test_garbage_collection.py +++ b/tests/test_garbage_collection.py @@ -1,6 +1,6 @@ import gc import platform -from typing import Any +from typing import Any, Iterable from weakref import WeakValueDictionary import pytest @@ -79,3 +79,42 @@ class MyModel(BaseModel): gc.collect(2) assert len(cache) == 0 + + +@pytest.mark.xfail( + condition=platform.python_implementation() == 'PyPy', reason='https://foss.heptapod.net/pypy/pypy/-/issues/3899' +) +def test_gc_validator_iterator() -> None: + # test for https://github.com/pydantic/pydantic/issues/9243 + class MyModel: + iter: Iterable[int] + + v = SchemaValidator( + core_schema.model_schema( + MyModel, + core_schema.model_fields_schema( + {'iter': core_schema.model_field(core_schema.generator_schema(core_schema.int_schema()))} + ), + ), + ) + + class MyIterable: + def __iter__(self): + return self + + def __next__(self): + raise StopIteration() + + cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary() + + for _ in range(10_000): + iterable = MyIterable() + cache[id(iterable)] = iterable + v.validate_python({'iter': iterable}) + del iterable + + gc.collect(0) + gc.collect(1) + gc.collect(2) + + assert len(cache) == 0