11use ahash:: AHashSet ;
2- use std:: hash :: Hash ;
2+ use std:: mem :: MaybeUninit ;
33
44type RecursionKey = (
55 // Identifier for the input object, e.g. the id() of a Python dict
@@ -14,7 +14,7 @@ type RecursionKey = (
1414/// It's used in `validators/definition` to detect when a reference is reused within itself.
1515#[ derive( Debug , Clone , Default ) ]
1616pub struct RecursionGuard {
17- ids : SmallContainer < RecursionKey > ,
17+ ids : RecursionStack ,
1818 // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
1919 // use one number for all validators
2020 depth : u8 ,
@@ -33,10 +33,10 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi
3333
3434impl RecursionGuard {
3535 // insert a new value
36- // * return `None ` if the array/set already had it in it
37- // * return `Some(index) ` if the array didn't have it in it and it was inserted
38- pub fn contains_or_insert ( & mut self , obj_id : usize , node_id : usize ) -> Option < usize > {
39- self . ids . contains_or_insert ( ( obj_id, node_id) )
36+ // * return `false ` if the stack already had it in it
37+ // * return `true ` if the stack didn't have it in it and it was inserted
38+ pub fn insert ( & mut self , obj_id : usize , node_id : usize ) -> bool {
39+ self . ids . insert ( ( obj_id, node_id) )
4040 }
4141
4242 // see #143 this is used as a backup in case the identity check recursion guard fails
@@ -68,76 +68,93 @@ impl RecursionGuard {
6868 self . depth = self . depth . saturating_sub ( 1 ) ;
6969 }
7070
71- pub fn remove ( & mut self , obj_id : usize , node_id : usize , index : usize ) {
72- self . ids . remove ( & ( obj_id, node_id) , index ) ;
71+ pub fn remove ( & mut self , obj_id : usize , node_id : usize ) {
72+ self . ids . remove ( & ( obj_id, node_id) ) ;
7373 }
7474}
7575
7676// trial and error suggests this is a good value, going higher causes array lookups to get significantly slower
7777const ARRAY_SIZE : usize = 16 ;
7878
7979#[ derive( Debug , Clone ) ]
80- enum SmallContainer < T > {
81- Array ( [ Option < T > ; ARRAY_SIZE ] ) ,
82- Set ( AHashSet < T > ) ,
80+ enum RecursionStack {
81+ Array {
82+ data : [ MaybeUninit < RecursionKey > ; ARRAY_SIZE ] ,
83+ len : usize ,
84+ } ,
85+ Set ( AHashSet < RecursionKey > ) ,
8386}
8487
85- impl < T : Copy > Default for SmallContainer < T > {
88+ impl Default for RecursionStack {
8689 fn default ( ) -> Self {
87- Self :: Array ( [ None ; ARRAY_SIZE ] )
90+ Self :: Array {
91+ data : std:: array:: from_fn ( |_| MaybeUninit :: uninit ( ) ) ,
92+ len : 0 ,
93+ }
8894 }
8995}
9096
91- impl < T : Eq + Hash + Clone > SmallContainer < T > {
97+ impl RecursionStack {
9298 // insert a new value
93- // * return `None ` if the array/set already had it in it
94- // * return `Some(index) ` if the array didn't have it in it and it was inserted
95- pub fn contains_or_insert ( & mut self , v : T ) -> Option < usize > {
99+ // * return `false ` if the stack already had it in it
100+ // * return `true ` if the stack didn't have it in it and it was inserted
101+ pub fn insert ( & mut self , v : RecursionKey ) -> bool {
96102 match self {
97- Self :: Array ( array) => {
98- for ( index, op_value) in array. iter_mut ( ) . enumerate ( ) {
99- if let Some ( existing) = op_value {
100- if existing == & v {
101- return None ;
103+ Self :: Array { data, len } => {
104+ if * len < ARRAY_SIZE {
105+ for value in data. iter ( ) . take ( * len) {
106+ // Safety: reading values within bounds
107+ if unsafe { value. assume_init ( ) } == v {
108+ return false ;
102109 }
103- } else {
104- * op_value = Some ( v) ;
105- return Some ( index) ;
106110 }
107- }
108111
109- // No array slots exist; convert to set
110- let mut set: AHashSet < T > = AHashSet :: with_capacity ( ARRAY_SIZE + 1 ) ;
111- for existing in array. iter_mut ( ) {
112- set. insert ( existing. take ( ) . unwrap ( ) ) ;
112+ data[ * len] . write ( v) ;
113+ * len += 1 ;
114+ true
115+ } else {
116+ let mut set = AHashSet :: with_capacity ( ARRAY_SIZE + 1 ) ;
117+ for existing in data. iter ( ) {
118+ // Safety: the array is fully initialized
119+ set. insert ( unsafe { existing. assume_init ( ) } ) ;
120+ }
121+ let inserted = set. insert ( v) ;
122+ * self = Self :: Set ( set) ;
123+ inserted
113124 }
114- set. insert ( v) ;
115- * self = Self :: Set ( set) ;
116- // id doesn't matter here as we'll be removing from a set
117- Some ( 0 )
118125 }
119126 // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
120127 // "If the set did not have this value present, `true` is returned."
121- Self :: Set ( set) => {
122- if set. insert ( v) {
123- // again id doesn't matter here as we'll be removing from a set
124- Some ( 0 )
125- } else {
126- None
127- }
128- }
128+ Self :: Set ( set) => set. insert ( v) ,
129129 }
130130 }
131131
132- pub fn remove ( & mut self , v : & T , index : usize ) {
132+ pub fn remove ( & mut self , v : & RecursionKey ) {
133133 match self {
134- Self :: Array ( array) => {
135- debug_assert ! ( array[ index] . as_ref( ) == Some ( v) , "remove did not match insert" ) ;
136- array[ index] = None ;
134+ Self :: Array { data, len } => {
135+ * len = len. checked_sub ( 1 ) . expect ( "remove from empty recursion guard" ) ;
136+ assert ! (
137+ // Safety: this is reading the back of the initialized array
138+ unsafe { data[ * len] . assume_init( ) } == * v,
139+ "remove did not match insert"
140+ ) ;
137141 }
138142 Self :: Set ( set) => {
139143 set. remove ( v) ;
140144 }
141145 }
142146 }
143147}
148+
149+ impl Drop for RecursionStack {
150+ fn drop ( & mut self ) {
151+ // This should compile away to a noop as Recursion>Key doesn't implement Drop, but it seemed
152+ // desirable to leave this in for safety in case that should change in the future
153+ if let Self :: Array { data, len } = self {
154+ for value in data. iter_mut ( ) . take ( * len) {
155+ // Safety: reading values within bounds
156+ unsafe { std:: ptr:: drop_in_place ( value. as_mut_ptr ( ) ) } ;
157+ }
158+ }
159+ }
160+ }
0 commit comments