Skip to content

Commit cf8dddd

Browse files
committed
codify the stack-based nature of the guard
1 parent e1190ed commit cf8dddd

File tree

5 files changed

+83
-66
lines changed

5 files changed

+83
-66
lines changed

src/recursion_guard.rs

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use ahash::AHashSet;
2-
use std::hash::Hash;
2+
use std::mem::MaybeUninit;
33

44
type RecursionKey = (
55
// Identifier for the input object, e.g. the id() of a Python dict
@@ -14,7 +14,7 @@ type RecursionKey = (
1414
/// It's used in `validators/definition` to detect when a reference is reused within itself.
1515
#[derive(Debug, Clone, Default)]
1616
pub struct RecursionGuard {
17-
ids: SmallContainer<RecursionKey>,
17+
ids: RecursionStack,
1818
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
1919
// use one number for all validators
2020
depth: u8,
@@ -33,10 +33,10 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi
3333

3434
impl RecursionGuard {
3535
// insert a new value
36-
// * return `None` if the array/set already had it in it
37-
// * return `Some(index)` if the array didn't have it in it and it was inserted
38-
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> Option<usize> {
39-
self.ids.contains_or_insert((obj_id, node_id))
36+
// * return `false` if the stack already had it in it
37+
// * return `true` if the stack didn't have it in it and it was inserted
38+
pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
39+
self.ids.insert((obj_id, node_id))
4040
}
4141

4242
// see #143: this is used as a backup in case the identity check recursion guard fails
@@ -68,76 +68,93 @@ impl RecursionGuard {
6868
self.depth = self.depth.saturating_sub(1);
6969
}
7070

71-
pub fn remove(&mut self, obj_id: usize, node_id: usize, index: usize) {
72-
self.ids.remove(&(obj_id, node_id), index);
71+
pub fn remove(&mut self, obj_id: usize, node_id: usize) {
72+
self.ids.remove(&(obj_id, node_id));
7373
}
7474
}
7575

7676
// trial and error suggests this is a good value, going higher causes array lookups to get significantly slower
7777
const ARRAY_SIZE: usize = 16;
7878

7979
#[derive(Debug, Clone)]
80-
enum SmallContainer<T> {
81-
Array([Option<T>; ARRAY_SIZE]),
82-
Set(AHashSet<T>),
80+
enum RecursionStack {
81+
Array {
82+
data: [MaybeUninit<RecursionKey>; ARRAY_SIZE],
83+
len: usize,
84+
},
85+
Set(AHashSet<RecursionKey>),
8386
}
8487

85-
impl<T: Copy> Default for SmallContainer<T> {
88+
impl Default for RecursionStack {
8689
fn default() -> Self {
87-
Self::Array([None; ARRAY_SIZE])
90+
Self::Array {
91+
data: std::array::from_fn(|_| MaybeUninit::uninit()),
92+
len: 0,
93+
}
8894
}
8995
}
9096

91-
impl<T: Eq + Hash + Clone> SmallContainer<T> {
97+
impl RecursionStack {
9298
// insert a new value
93-
// * return `None` if the array/set already had it in it
94-
// * return `Some(index)` if the array didn't have it in it and it was inserted
95-
pub fn contains_or_insert(&mut self, v: T) -> Option<usize> {
99+
// * return `false` if the stack already had it in it
100+
// * return `true` if the stack didn't have it in it and it was inserted
101+
pub fn insert(&mut self, v: RecursionKey) -> bool {
96102
match self {
97-
Self::Array(array) => {
98-
for (index, op_value) in array.iter_mut().enumerate() {
99-
if let Some(existing) = op_value {
100-
if existing == &v {
101-
return None;
103+
Self::Array { data, len } => {
104+
if *len < ARRAY_SIZE {
105+
for value in data.iter().take(*len) {
106+
// Safety: only the first `len` slots are read, and those have been initialized
107+
if unsafe { value.assume_init() } == v {
108+
return false;
102109
}
103-
} else {
104-
*op_value = Some(v);
105-
return Some(index);
106110
}
107-
}
108111

109-
// No array slots exist; convert to set
110-
let mut set: AHashSet<T> = AHashSet::with_capacity(ARRAY_SIZE + 1);
111-
for existing in array.iter_mut() {
112-
set.insert(existing.take().unwrap());
112+
data[*len].write(v);
113+
*len += 1;
114+
true
115+
} else {
116+
let mut set = AHashSet::with_capacity(ARRAY_SIZE + 1);
117+
for existing in data.iter() {
118+
// Safety: the array is fully initialized
119+
set.insert(unsafe { existing.assume_init() });
120+
}
121+
let inserted = set.insert(v);
122+
*self = Self::Set(set);
123+
inserted
113124
}
114-
set.insert(v);
115-
*self = Self::Set(set);
116-
// id doesn't matter here as we'll be removing from a set
117-
Some(0)
118125
}
119126
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
120127
// "If the set did not have this value present, `true` is returned."
121-
Self::Set(set) => {
122-
if set.insert(v) {
123-
// again id doesn't matter here as we'll be removing from a set
124-
Some(0)
125-
} else {
126-
None
127-
}
128-
}
128+
Self::Set(set) => set.insert(v),
129129
}
130130
}
131131

132-
pub fn remove(&mut self, v: &T, index: usize) {
132+
pub fn remove(&mut self, v: &RecursionKey) {
133133
match self {
134-
Self::Array(array) => {
135-
debug_assert!(array[index].as_ref() == Some(v), "remove did not match insert");
136-
array[index] = None;
134+
Self::Array { data, len } => {
135+
*len = len.checked_sub(1).expect("remove from empty recursion guard");
136+
assert!(
137+
// Safety: this is reading the back of the initialized array
138+
unsafe { data[*len].assume_init() } == *v,
139+
"remove did not match insert"
140+
);
137141
}
138142
Self::Set(set) => {
139143
set.remove(v);
140144
}
141145
}
142146
}
143147
}
148+
149+
impl Drop for RecursionStack {
150+
fn drop(&mut self) {
151+
// This should compile away to a noop as RecursionKey doesn't implement Drop, but it seemed
152+
// desirable to leave this in for safety in case that should change in the future
153+
if let Self::Array { data, len } = self {
154+
for value in data.iter_mut().take(*len) {
155+
// Safety: only the first `len` slots are dropped, and those have been initialized
156+
unsafe { std::ptr::drop_in_place(value.as_mut_ptr()) };
157+
}
158+
}
159+
}
160+
}

src/serializers/extra.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -345,24 +345,24 @@ pub struct SerRecursionGuard {
345345
}
346346

347347
impl SerRecursionGuard {
348-
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<(usize, usize)> {
348+
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
349349
let id = value.as_ptr() as usize;
350350
let mut guard = self.guard.borrow_mut();
351351

352-
if let Some(insert_index) = guard.contains_or_insert(id, def_ref_id) {
352+
if guard.insert(id, def_ref_id) {
353353
if guard.incr_depth() {
354354
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
355355
} else {
356-
Ok((id, insert_index))
356+
Ok(id)
357357
}
358358
} else {
359359
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
360360
}
361361
}
362362

363-
pub fn pop(&self, id: usize, def_ref_id: usize, insert_index: usize) {
363+
pub fn pop(&self, id: usize, def_ref_id: usize) {
364364
let mut guard = self.guard.borrow_mut();
365365
guard.decr_depth();
366-
guard.remove(id, def_ref_id, insert_index);
366+
guard.remove(id, def_ref_id);
367367
}
368368
}

src/serializers/infer.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ pub(crate) fn infer_to_python_known(
4545
extra: &Extra,
4646
) -> PyResult<PyObject> {
4747
let py = value.py();
48-
let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
48+
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
4949
Ok(id) => id,
5050
Err(e) => {
5151
return match extra.mode {
@@ -226,7 +226,7 @@ pub(crate) fn infer_to_python_known(
226226
if let Some(fallback) = extra.fallback {
227227
let next_value = fallback.call1((value,))?;
228228
let next_result = infer_to_python(next_value, include, exclude, extra);
229-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
229+
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
230230
return next_result;
231231
} else if extra.serialize_unknown {
232232
serialize_unknown(value).into_py(py)
@@ -284,15 +284,15 @@ pub(crate) fn infer_to_python_known(
284284
if let Some(fallback) = extra.fallback {
285285
let next_value = fallback.call1((value,))?;
286286
let next_result = infer_to_python(next_value, include, exclude, extra);
287-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
287+
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
288288
return next_result;
289289
}
290290
value.into_py(py)
291291
}
292292
_ => value.into_py(py),
293293
},
294294
};
295-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
295+
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
296296
Ok(value)
297297
}
298298

@@ -351,7 +351,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
351351
exclude: Option<&PyAny>,
352352
extra: &Extra,
353353
) -> Result<S::Ok, S::Error> {
354-
let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
354+
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
355355
Ok(v) => v,
356356
Err(e) => {
357357
return if extra.serialize_unknown {
@@ -534,7 +534,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
534534
if let Some(fallback) = extra.fallback {
535535
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
536536
let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
537-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
537+
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
538538
return next_result;
539539
} else if extra.serialize_unknown {
540540
serializer.serialize_str(&serialize_unknown(value))
@@ -548,7 +548,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
548548
}
549549
}
550550
};
551-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index);
551+
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
552552
ser_result
553553
}
554554

src/serializers/type_serializers/definitions.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ impl TypeSerializer for DefinitionRefSerializer {
7070
) -> PyResult<PyObject> {
7171
self.definition.read(|comb_serializer| {
7272
let comb_serializer = comb_serializer.unwrap();
73-
let (value_id, insert_index) = extra.rec_guard.add(value, self.definition.id())?;
73+
let value_id = extra.rec_guard.add(value, self.definition.id())?;
7474
let r = comb_serializer.to_python(value, include, exclude, extra);
75-
extra.rec_guard.pop(value_id, self.definition.id(), insert_index);
75+
extra.rec_guard.pop(value_id, self.definition.id());
7676
r
7777
})
7878
}
@@ -91,12 +91,12 @@ impl TypeSerializer for DefinitionRefSerializer {
9191
) -> Result<S::Ok, S::Error> {
9292
self.definition.read(|comb_serializer| {
9393
let comb_serializer = comb_serializer.unwrap();
94-
let (value_id, insert_index) = extra
94+
let value_id = extra
9595
.rec_guard
9696
.add(value, self.definition.id())
9797
.map_err(py_err_se_err)?;
9898
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
99-
extra.rec_guard.pop(value_id, self.definition.id(), insert_index);
99+
extra.rec_guard.pop(value_id, self.definition.id());
100100
r
101101
})
102102
}

src/validators/definitions.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ impl Validator for DefinitionRefValidator {
7676
self.definition.read(|validator| {
7777
let validator = validator.unwrap();
7878
if let Some(id) = input.identity() {
79-
if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) {
79+
if state.recursion_guard.insert(id, self.definition.id()) {
8080
if state.recursion_guard.incr_depth() {
8181
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
8282
}
8383
let output = validator.validate(py, input, state);
84-
state.recursion_guard.remove(id, self.definition.id(), insert_index);
84+
state.recursion_guard.remove(id, self.definition.id());
8585
state.recursion_guard.decr_depth();
8686
output
8787
} else {
@@ -105,12 +105,12 @@ impl Validator for DefinitionRefValidator {
105105
self.definition.read(|validator| {
106106
let validator = validator.unwrap();
107107
if let Some(id) = obj.identity() {
108-
if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) {
108+
if state.recursion_guard.insert(id, self.definition.id()) {
109109
if state.recursion_guard.incr_depth() {
110110
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
111111
}
112112
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
113-
state.recursion_guard.remove(id, self.definition.id(), insert_index);
113+
state.recursion_guard.remove(id, self.definition.id());
114114
state.recursion_guard.decr_depth();
115115
output
116116
} else {

0 commit comments

Comments
 (0)