@@ -9,6 +9,29 @@ import (
9
9
"github.com/patrickmn/go-cache"
10
10
)
11
11
12
+ const (
13
+ strictPrefix = "strict"
14
+ hardcorePrefix = "hardcore"
15
+ )
16
+
17
+ var _counters = & sync.Pool {
18
+ New : func () interface {} {
19
+ i := & atomic.Int64 {}
20
+ i .Store (0 )
21
+ return i
22
+ },
23
+ }
24
+
25
+ func getCounter () * atomic.Int64 {
26
+ got := _counters .Get ().(* atomic.Int64 )
27
+ got .Store (0 )
28
+ return got
29
+ }
30
+
31
+ func putCounter (i * atomic.Int64 ) {
32
+ _counters .Put (i )
33
+ }
34
+
12
35
/*NewDefaultLimiter returns a ratelimiter with default settings without Strict mode.
13
36
* Default window: 25 seconds
14
37
* Default burst: 25 requests */
@@ -70,28 +93,40 @@ func NewHardcoreLimiter(window int, burst int) *Limiter {
70
93
return l
71
94
}
72
95
96
+ // ResetItem removes an Identity from the limiter's cache.
97
+ // This effectively resets the rate limit for the Identity.
73
98
func (q * Limiter ) ResetItem (from Identity ) {
74
99
q .Patrons .Delete (from .UniqueKey ())
75
- q .debugPrintf ("ratelimit for %s has been reset" , from .UniqueKey ())
100
+ q .debugPrintf (msgRateLimitedRst , from .UniqueKey ())
101
+ }
102
+
103
+ func (q * Limiter ) onEvict (src string , count interface {}) {
104
+ q .debugPrintf (msgRateLimitExpired , src , count )
105
+ putCounter (count .(* atomic.Int64 ))
106
+
76
107
}
77
108
78
109
func newLimiter (policy Policy ) * Limiter {
79
110
window := time .Duration (policy .Window ) * time .Second
80
- return & Limiter {
111
+ q := & Limiter {
81
112
Ruleset : policy ,
82
113
Patrons : cache .New (window , time .Duration (policy .Window )* time .Second ),
83
- known : make (map [interface {}]* int64 ),
114
+ known : make (map [interface {}]* atomic. Int64 ),
84
115
RWMutex : & sync.RWMutex {},
85
116
debugMutex : & sync.RWMutex {},
86
117
debug : DebugDisabled ,
87
118
}
119
+ q .Patrons .OnEvicted (q .onEvict )
120
+ return q
88
121
}
89
122
90
- func intPtr (i int64 ) * int64 {
91
- return & i
123
+ func intPtr (i int64 ) * atomic.Int64 {
124
+ a := getCounter ()
125
+ a .Store (i )
126
+ return a
92
127
}
93
128
94
- func (q * Limiter ) getHitsPtr (src string ) * int64 {
129
+ func (q * Limiter ) getHitsPtr (src string ) * atomic. Int64 {
95
130
q .RLock ()
96
131
if _ , ok := q .known [src ]; ok {
97
132
oldPtr := q .known [src ]
@@ -100,29 +135,29 @@ func (q *Limiter) getHitsPtr(src string) *int64 {
100
135
}
101
136
q .RUnlock ()
102
137
q .Lock ()
103
- newPtr := intPtr ( 0 )
138
+ newPtr := getCounter ( )
104
139
q .known [src ] = newPtr
105
140
q .Unlock ()
106
141
return newPtr
107
142
}
108
143
109
144
func (q * Limiter ) strictLogic (src string , count * atomic.Int64 ) {
110
145
knownHits := q .getHitsPtr (src )
111
- atomic . AddInt64 ( knownHits , 1 )
146
+ knownHits . Add ( 1 )
112
147
var extwindow int64
113
- prefix := "hardcore"
148
+ prefix := hardcorePrefix
114
149
switch {
115
150
case q .Ruleset .Hardcore && q .Ruleset .Window > 1 :
116
- extwindow = atomic . LoadInt64 ( knownHits ) * q .Ruleset .Window
151
+ extwindow = knownHits . Load ( ) * q .Ruleset .Window
117
152
case q .Ruleset .Hardcore && q .Ruleset .Window <= 1 :
118
- extwindow = atomic . LoadInt64 ( knownHits ) * 2
153
+ extwindow = knownHits . Load ( ) * 2
119
154
case ! q .Ruleset .Hardcore :
120
- prefix = "strict"
121
- extwindow = atomic . LoadInt64 ( knownHits ) + q .Ruleset .Window
155
+ prefix = strictPrefix
156
+ extwindow = knownHits . Load ( ) + q .Ruleset .Window
122
157
}
123
158
exttime := time .Duration (extwindow ) * time .Second
124
159
_ = q .Patrons .Replace (src , count , exttime )
125
- q .debugPrintf ("%s ratelimit for %s: last count %d. time: %s" , prefix , src , count , exttime )
160
+ q .debugPrintf (msgRateLimitStrict , prefix , src , count . Load () , exttime )
126
161
}
127
162
128
163
func (q * Limiter ) CheckStringer (from fmt.Stringer ) bool {
@@ -132,21 +167,17 @@ func (q *Limiter) CheckStringer(from fmt.Stringer) bool {
132
167
133
168
// Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status.
134
169
func (q * Limiter ) Check (from Identity ) (limited bool ) {
135
- var aval any
136
170
var count int64
137
- var ok bool
138
- aval , ok = q .Patrons .Get (from .UniqueKey ())
171
+ aval , ok := q .Patrons .Get (from .UniqueKey ())
139
172
switch {
140
173
case ! ok :
141
- q .debugPrintf ("ratelimit %s (new) " , from .UniqueKey ())
142
- aval = & atomic.Int64 {}
143
- aval .(* atomic.Int64 ).Store (1 )
174
+ q .debugPrintf (msgRateLimitedNew , from .UniqueKey ())
175
+ aval = intPtr (1 )
144
176
// We can't reproduce this throwing an error, we can only assume that the key is new.
145
177
_ = q .Patrons .Add (from .UniqueKey (), aval , time .Duration (q .Ruleset .Window )* time .Second )
146
178
return false
147
- case ok && aval != nil :
179
+ case aval != nil :
148
180
count = aval .(* atomic.Int64 ).Add (1 )
149
- _ = q .Patrons .Replace (from .UniqueKey (), aval , time .Duration (q .Ruleset .Window )* time .Second )
150
181
if count < q .Ruleset .Burst {
151
182
return false
152
183
}
@@ -155,14 +186,12 @@ func (q *Limiter) Check(from Identity) (limited bool) {
155
186
q .strictLogic (from .UniqueKey (), aval .(* atomic.Int64 ))
156
187
return true
157
188
}
158
- q .debugPrintf ("ratelimit %s: last count %d. time: %s" ,
159
- from .UniqueKey (), count , time .Duration (q .Ruleset .Window )* time .Second )
189
+ q .debugPrintf (msgRateLimited , from .UniqueKey (), count , time .Duration (q .Ruleset .Window )* time .Second )
160
190
return true
161
191
}
162
192
163
193
// Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count.
164
194
func (q * Limiter ) Peek (from Identity ) bool {
165
- q .Patrons .DeleteExpired ()
166
195
if ct , ok := q .Patrons .Get (from .UniqueKey ()); ok {
167
196
count := ct .(* atomic.Int64 ).Load ()
168
197
if count > q .Ruleset .Burst {
0 commit comments