11use ahash:: AHashSet ;
2+ use std:: mem:: MaybeUninit ;
23
34type RecursionKey = (
45 // Identifier for the input object, e.g. the id() of a Python dict
@@ -13,56 +14,147 @@ type RecursionKey = (
1314/// It's used in `validators/definition` to detect when a reference is reused within itself.
1415#[ derive( Debug , Clone , Default ) ]
1516pub struct RecursionGuard {
16- ids : Option < AHashSet < RecursionKey > > ,
17+ ids : RecursionStack ,
1718 // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
1819 // use one number for all validators
19- depth : u16 ,
20+ depth : u8 ,
2021}
2122
/// A hard limit to avoid stack overflows when rampant recursion occurs.
///
/// The value is `255` (the maximum of `u8`) unless building for wasm, Windows or PyPy.
/// `RecursionGuard::incr_depth` has two `cfg`-selected implementations whose debug
/// assertions rely on exactly that split, so keep the two in sync when editing this.
pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
    // wasm and windows PyPy have very limited stack sizes
    49
} else if cfg!(any(PyPy, windows)) {
    // PyPy and Windows in general have more restricted stack space
    99
} else {
    255
};
3233
3334impl RecursionGuard {
34- // insert a new id into the set, return whether the set already had the id in it
35- pub fn contains_or_insert ( & mut self , obj_id : usize , node_id : usize ) -> bool {
36- match self . ids {
37- // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
38- // "If the set did not have this value present, `true` is returned."
39- Some ( ref mut set) => !set. insert ( ( obj_id, node_id) ) ,
40- None => {
41- let mut set: AHashSet < RecursionKey > = AHashSet :: with_capacity ( 10 ) ;
42- set. insert ( ( obj_id, node_id) ) ;
43- self . ids = Some ( set) ;
44- false
45- }
46- }
35+ // insert a new value
36+ // * return `false` if the stack already had it in it
37+ // * return `true` if the stack didn't have it in it and it was inserted
38+ pub fn insert ( & mut self , obj_id : usize , node_id : usize ) -> bool {
39+ self . ids . insert ( ( obj_id, node_id) )
4740 }
4841
4942 // see #143 this is used as a backup in case the identity check recursion guard fails
5043 #[ must_use]
44+ #[ cfg( any( target_family = "wasm" , windows, PyPy ) ) ]
5145 pub fn incr_depth ( & mut self ) -> bool {
52- self . depth += 1 ;
53- self . depth >= RECURSION_GUARD_LIMIT
46+ // use saturating_add as it's faster (since there's no error path)
47+ // and the RECURSION_GUARD_LIMIT check will be hit before it overflows
48+ debug_assert ! ( RECURSION_GUARD_LIMIT < 255 ) ;
49+ self . depth = self . depth . saturating_add ( 1 ) ;
50+ self . depth > RECURSION_GUARD_LIMIT
51+ }
52+
53+ #[ must_use]
54+ #[ cfg( not( any( target_family = "wasm" , windows, PyPy ) ) ) ]
55+ pub fn incr_depth ( & mut self ) -> bool {
56+ debug_assert_eq ! ( RECURSION_GUARD_LIMIT , 255 ) ;
57+ // use checked_add to check if we've hit the limit
58+ if let Some ( depth) = self . depth . checked_add ( 1 ) {
59+ self . depth = depth;
60+ false
61+ } else {
62+ true
63+ }
5464 }
5565
5666 pub fn decr_depth ( & mut self ) {
57- self . depth -= 1 ;
67+ // for the same reason as incr_depth, use saturating_sub
68+ self . depth = self . depth . saturating_sub ( 1 ) ;
5869 }
5970
6071 pub fn remove ( & mut self , obj_id : usize , node_id : usize ) {
61- match self . ids {
62- Some ( ref mut set) => {
63- set. remove ( & ( obj_id, node_id) ) ;
72+ self . ids . remove ( & ( obj_id, node_id) ) ;
73+ }
74+ }
75+
// Trial and error suggests this is a good value; going higher causes array lookups to get
// significantly slower.
const ARRAY_SIZE: usize = 16;
78+
79+ #[ derive( Debug , Clone ) ]
80+ enum RecursionStack {
81+ Array {
82+ data : [ MaybeUninit < RecursionKey > ; ARRAY_SIZE ] ,
83+ len : usize ,
84+ } ,
85+ Set ( AHashSet < RecursionKey > ) ,
86+ }
87+
88+ impl Default for RecursionStack {
89+ fn default ( ) -> Self {
90+ Self :: Array {
91+ data : std:: array:: from_fn ( |_| MaybeUninit :: uninit ( ) ) ,
92+ len : 0 ,
93+ }
94+ }
95+ }
96+
97+ impl RecursionStack {
98+ // insert a new value
99+ // * return `false` if the stack already had it in it
100+ // * return `true` if the stack didn't have it in it and it was inserted
101+ pub fn insert ( & mut self , v : RecursionKey ) -> bool {
102+ match self {
103+ Self :: Array { data, len } => {
104+ if * len < ARRAY_SIZE {
105+ for value in data. iter ( ) . take ( * len) {
106+ // Safety: reading values within bounds
107+ if unsafe { value. assume_init ( ) } == v {
108+ return false ;
109+ }
110+ }
111+
112+ data[ * len] . write ( v) ;
113+ * len += 1 ;
114+ true
115+ } else {
116+ let mut set = AHashSet :: with_capacity ( ARRAY_SIZE + 1 ) ;
117+ for existing in data. iter ( ) {
118+ // Safety: the array is fully initialized
119+ set. insert ( unsafe { existing. assume_init ( ) } ) ;
120+ }
121+ let inserted = set. insert ( v) ;
122+ * self = Self :: Set ( set) ;
123+ inserted
124+ }
125+ }
126+ // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
127+ // "If the set did not have this value present, `true` is returned."
128+ Self :: Set ( set) => set. insert ( v) ,
129+ }
130+ }
131+
132+ pub fn remove ( & mut self , v : & RecursionKey ) {
133+ match self {
134+ Self :: Array { data, len } => {
135+ * len = len. checked_sub ( 1 ) . expect ( "remove from empty recursion guard" ) ;
136+ // Safety: this is reading what was the back of the initialized array
137+ let removed = unsafe { data. get_unchecked_mut ( * len) } ;
138+ assert ! ( unsafe { removed. assume_init_ref( ) } == v, "remove did not match insert" ) ;
139+ // this should compile away to a noop
140+ unsafe { std:: ptr:: drop_in_place ( removed. as_mut_ptr ( ) ) }
141+ }
142+ Self :: Set ( set) => {
143+ set. remove ( v) ;
144+ }
145+ }
146+ }
147+ }
148+
149+ impl Drop for RecursionStack {
150+ fn drop ( & mut self ) {
151+ // This should compile away to a noop as Recursion>Key doesn't implement Drop, but it seemed
152+ // desirable to leave this in for safety in case that should change in the future
153+ if let Self :: Array { data, len } = self {
154+ for value in data. iter_mut ( ) . take ( * len) {
155+ // Safety: reading values within bounds
156+ unsafe { std:: ptr:: drop_in_place ( value. as_mut_ptr ( ) ) } ;
64157 }
65- None => unreachable ! ( ) ,
66- } ;
158+ }
67159 }
68160}
0 commit comments