Skip to content

Commit fe846fd

Browse files
committed
Perf: Increase speed and reduce memory allocations
1 parent d3c6f59 commit fe846fd

File tree

6 files changed

+349
-72
lines changed

6 files changed

+349
-72
lines changed

Diff for: debug.go

+16-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,24 @@ import (
55
"sync/atomic"
66
)
77

8+
const (
9+
msgRateLimitExpired = "ratelimit (expired): %s | last count [%d]"
10+
msgDebugEnabled = "rate5 debug enabled"
11+
msgRateLimitedRst = "ratelimit for %s has been reset"
12+
msgRateLimitedNew = "ratelimit %s (new) "
13+
msgRateLimited = "ratelimit %s: last count %d. time: %s"
14+
msgRateLimitStrict = "%s ratelimit for %s: last count %d. time: %s"
15+
)
16+
817
func (q *Limiter) debugPrintf(format string, a ...interface{}) {
918
if atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugDisabled) {
1019
return
1120
}
21+
if len(a) == 2 {
22+
if _, ok := a[1].(*atomic.Int64); ok {
23+
a[1] = a[1].(*atomic.Int64).Load()
24+
}
25+
}
1226
msg := fmt.Sprintf(format, a...)
1327
select {
1428
case q.debugChannel <- msg:
@@ -21,15 +35,15 @@ func (q *Limiter) debugPrintf(format string, a ...interface{}) {
2135

2236
func (q *Limiter) setDebugEvict() {
2337
q.Patrons.OnEvicted(func(src string, count interface{}) {
24-
q.debugPrintf("ratelimit (expired): %s | last count [%d]", src, count)
38+
q.debugPrintf(msgRateLimitExpired, src, count.(*atomic.Int64).Load())
2539
})
2640
}
2741

2842
func (q *Limiter) SetDebug(on bool) {
2943
switch on {
3044
case true:
3145
if atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled) {
32-
q.debugPrintf("rate5 debug enabled")
46+
q.debugPrintf(msgDebugEnabled)
3347
}
3448
case false:
3549
atomic.CompareAndSwapUint32(&q.debug, DebugEnabled, DebugDisabled)

Diff for: models.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package rate5
33
import (
44
"fmt"
55
"sync"
6+
"sync/atomic"
67

78
"github.com/patrickmn/go-cache"
89
)
@@ -46,7 +47,7 @@ type Limiter struct {
4647
debug uint32
4748
debugChannel chan string
4849
debugLost int64
49-
known map[interface{}]*int64
50+
known map[interface{}]*atomic.Int64
5051
debugMutex *sync.RWMutex
5152
*sync.RWMutex
5253
}

Diff for: ratelimiter.go

+54-25
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,29 @@ import (
99
"github.com/patrickmn/go-cache"
1010
)
1111

12+
const (
13+
strictPrefix = "strict"
14+
hardcorePrefix = "hardcore"
15+
)
16+
17+
var _counters = &sync.Pool{
18+
New: func() interface{} {
19+
i := &atomic.Int64{}
20+
i.Store(0)
21+
return i
22+
},
23+
}
24+
25+
func getCounter() *atomic.Int64 {
26+
got := _counters.Get().(*atomic.Int64)
27+
got.Store(0)
28+
return got
29+
}
30+
31+
func putCounter(i *atomic.Int64) {
32+
_counters.Put(i)
33+
}
34+
1235
/*NewDefaultLimiter returns a ratelimiter with default settings without Strict mode.
1336
* Default window: 25 seconds
1437
* Default burst: 25 requests */
@@ -70,28 +93,40 @@ func NewHardcoreLimiter(window int, burst int) *Limiter {
7093
return l
7194
}
7295

96+
// ResetItem removes an Identity from the limiter's cache.
97+
// This effectively resets the rate limit for the Identity.
7398
func (q *Limiter) ResetItem(from Identity) {
7499
q.Patrons.Delete(from.UniqueKey())
75-
q.debugPrintf("ratelimit for %s has been reset", from.UniqueKey())
100+
q.debugPrintf(msgRateLimitedRst, from.UniqueKey())
101+
}
102+
103+
func (q *Limiter) onEvict(src string, count interface{}) {
104+
q.debugPrintf(msgRateLimitExpired, src, count)
105+
putCounter(count.(*atomic.Int64))
106+
76107
}
77108

78109
func newLimiter(policy Policy) *Limiter {
79110
window := time.Duration(policy.Window) * time.Second
80-
return &Limiter{
111+
q := &Limiter{
81112
Ruleset: policy,
82113
Patrons: cache.New(window, time.Duration(policy.Window)*time.Second),
83-
known: make(map[interface{}]*int64),
114+
known: make(map[interface{}]*atomic.Int64),
84115
RWMutex: &sync.RWMutex{},
85116
debugMutex: &sync.RWMutex{},
86117
debug: DebugDisabled,
87118
}
119+
q.Patrons.OnEvicted(q.onEvict)
120+
return q
88121
}
89122

90-
func intPtr(i int64) *int64 {
91-
return &i
123+
func intPtr(i int64) *atomic.Int64 {
124+
a := getCounter()
125+
a.Store(i)
126+
return a
92127
}
93128

94-
func (q *Limiter) getHitsPtr(src string) *int64 {
129+
func (q *Limiter) getHitsPtr(src string) *atomic.Int64 {
95130
q.RLock()
96131
if _, ok := q.known[src]; ok {
97132
oldPtr := q.known[src]
@@ -100,29 +135,29 @@ func (q *Limiter) getHitsPtr(src string) *int64 {
100135
}
101136
q.RUnlock()
102137
q.Lock()
103-
newPtr := intPtr(0)
138+
newPtr := getCounter()
104139
q.known[src] = newPtr
105140
q.Unlock()
106141
return newPtr
107142
}
108143

109144
func (q *Limiter) strictLogic(src string, count *atomic.Int64) {
110145
knownHits := q.getHitsPtr(src)
111-
atomic.AddInt64(knownHits, 1)
146+
knownHits.Add(1)
112147
var extwindow int64
113-
prefix := "hardcore"
148+
prefix := hardcorePrefix
114149
switch {
115150
case q.Ruleset.Hardcore && q.Ruleset.Window > 1:
116-
extwindow = atomic.LoadInt64(knownHits) * q.Ruleset.Window
151+
extwindow = knownHits.Load() * q.Ruleset.Window
117152
case q.Ruleset.Hardcore && q.Ruleset.Window <= 1:
118-
extwindow = atomic.LoadInt64(knownHits) * 2
153+
extwindow = knownHits.Load() * 2
119154
case !q.Ruleset.Hardcore:
120-
prefix = "strict"
121-
extwindow = atomic.LoadInt64(knownHits) + q.Ruleset.Window
155+
prefix = strictPrefix
156+
extwindow = knownHits.Load() + q.Ruleset.Window
122157
}
123158
exttime := time.Duration(extwindow) * time.Second
124159
_ = q.Patrons.Replace(src, count, exttime)
125-
q.debugPrintf("%s ratelimit for %s: last count %d. time: %s", prefix, src, count, exttime)
160+
q.debugPrintf(msgRateLimitStrict, prefix, src, count.Load(), exttime)
126161
}
127162

128163
func (q *Limiter) CheckStringer(from fmt.Stringer) bool {
@@ -132,21 +167,17 @@ func (q *Limiter) CheckStringer(from fmt.Stringer) bool {
132167

133168
// Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status.
134169
func (q *Limiter) Check(from Identity) (limited bool) {
135-
var aval any
136170
var count int64
137-
var ok bool
138-
aval, ok = q.Patrons.Get(from.UniqueKey())
171+
aval, ok := q.Patrons.Get(from.UniqueKey())
139172
switch {
140173
case !ok:
141-
q.debugPrintf("ratelimit %s (new) ", from.UniqueKey())
142-
aval = &atomic.Int64{}
143-
aval.(*atomic.Int64).Store(1)
174+
q.debugPrintf(msgRateLimitedNew, from.UniqueKey())
175+
aval = intPtr(1)
144176
// We can't reproduce this throwing an error, we can only assume that the key is new.
145177
_ = q.Patrons.Add(from.UniqueKey(), aval, time.Duration(q.Ruleset.Window)*time.Second)
146178
return false
147-
case ok && aval != nil:
179+
case aval != nil:
148180
count = aval.(*atomic.Int64).Add(1)
149-
_ = q.Patrons.Replace(from.UniqueKey(), aval, time.Duration(q.Ruleset.Window)*time.Second)
150181
if count < q.Ruleset.Burst {
151182
return false
152183
}
@@ -155,14 +186,12 @@ func (q *Limiter) Check(from Identity) (limited bool) {
155186
q.strictLogic(from.UniqueKey(), aval.(*atomic.Int64))
156187
return true
157188
}
158-
q.debugPrintf("ratelimit %s: last count %d. time: %s",
159-
from.UniqueKey(), count, time.Duration(q.Ruleset.Window)*time.Second)
189+
q.debugPrintf(msgRateLimited, from.UniqueKey(), count, time.Duration(q.Ruleset.Window)*time.Second)
160190
return true
161191
}
162192

163193
// Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count.
164194
func (q *Limiter) Peek(from Identity) bool {
165-
q.Patrons.DeleteExpired()
166195
if ct, ok := q.Patrons.Get(from.UniqueKey()); ok {
167196
count := ct.(*atomic.Int64).Load()
168197
if count > q.Ruleset.Burst {

Diff for: ratelimiter_test.go

+15-15
Original file line numberDiff line numberDiff line change
@@ -156,23 +156,23 @@ func Test_NewLimiter(t *testing.T) {
156156

157157
func Test_NewDefaultStrictLimiter(t *testing.T) {
158158
limiter := NewDefaultStrictLimiter()
159-
ctx, cancel := context.WithCancel(context.Background())
160-
go watchDebug(ctx, limiter, t)
159+
// ctx, cancel := context.WithCancel(context.Background())
160+
// go watchDebug(ctx, limiter, t)
161161
time.Sleep(25 * time.Millisecond)
162162
for n := 0; n < 25; n++ {
163163
limiter.Check(dummyTicker)
164164
}
165165
peekCheckLimited(t, limiter, false, false)
166166
limiter.Check(dummyTicker)
167167
peekCheckLimited(t, limiter, true, false)
168-
cancel()
168+
// cancel()
169169
limiter = nil
170170
}
171171

172172
func Test_NewStrictLimiter(t *testing.T) {
173173
limiter := NewStrictLimiter(5, 1)
174-
ctx, cancel := context.WithCancel(context.Background())
175-
go watchDebug(ctx, limiter, t)
174+
// ctx, cancel := context.WithCancel(context.Background())
175+
// go watchDebug(ctx, limiter, t)
176176
limiter.Check(dummyTicker)
177177
peekCheckLimited(t, limiter, false, false)
178178
limiter.Check(dummyTicker)
@@ -191,7 +191,7 @@ func Test_NewStrictLimiter(t *testing.T) {
191191
peekCheckLimited(t, limiter, true, false)
192192
time.Sleep(8 * time.Second)
193193
peekCheckLimited(t, limiter, false, false)
194-
cancel()
194+
// cancel()
195195
limiter = nil
196196
}
197197

@@ -350,8 +350,8 @@ func Test_debugChannelOverflow(t *testing.T) {
350350

351351
func BenchmarkCheck(b *testing.B) {
352352
b.StopTimer()
353-
b.ReportAllocs()
354353
limiter := NewDefaultLimiter()
354+
b.ReportAllocs()
355355
b.StartTimer()
356356
for n := 0; n < b.N; n++ {
357357
limiter.Check(dummyTicker)
@@ -360,8 +360,8 @@ func BenchmarkCheck(b *testing.B) {
360360

361361
func BenchmarkCheckHardcore(b *testing.B) {
362362
b.StopTimer()
363-
b.ReportAllocs()
364363
limiter := NewHardcoreLimiter(25, 25)
364+
b.ReportAllocs()
365365
b.StartTimer()
366366
for n := 0; n < b.N; n++ {
367367
limiter.Check(dummyTicker)
@@ -370,8 +370,8 @@ func BenchmarkCheckHardcore(b *testing.B) {
370370

371371
func BenchmarkCheckStrict(b *testing.B) {
372372
b.StopTimer()
373-
b.ReportAllocs()
374373
limiter := NewStrictLimiter(25, 25)
374+
b.ReportAllocs()
375375
b.StartTimer()
376376
for n := 0; n < b.N; n++ {
377377
limiter.Check(dummyTicker)
@@ -380,8 +380,8 @@ func BenchmarkCheckStrict(b *testing.B) {
380380

381381
func BenchmarkCheckStringer(b *testing.B) {
382382
b.StopTimer()
383-
b.ReportAllocs()
384383
limiter := NewDefaultLimiter()
384+
b.ReportAllocs()
385385
b.StartTimer()
386386
for n := 0; n < b.N; n++ {
387387
limiter.CheckStringer(dummyTicker)
@@ -390,8 +390,8 @@ func BenchmarkCheckStringer(b *testing.B) {
390390

391391
func BenchmarkPeek(b *testing.B) {
392392
b.StopTimer()
393-
b.ReportAllocs()
394393
limiter := NewDefaultLimiter()
394+
b.ReportAllocs()
395395
b.StartTimer()
396396
for n := 0; n < b.N; n++ {
397397
limiter.Peek(dummyTicker)
@@ -400,8 +400,8 @@ func BenchmarkPeek(b *testing.B) {
400400

401401
func BenchmarkConcurrentCheck(b *testing.B) {
402402
b.StopTimer()
403-
b.ReportAllocs()
404403
limiter := NewDefaultLimiter()
404+
b.ReportAllocs()
405405
b.StartTimer()
406406
b.RunParallel(func(pb *testing.PB) {
407407
for pb.Next() {
@@ -412,8 +412,8 @@ func BenchmarkConcurrentCheck(b *testing.B) {
412412

413413
func BenchmarkConcurrentSetAndCheckHardcore(b *testing.B) {
414414
b.StopTimer()
415-
b.ReportAllocs()
416415
limiter := NewHardcoreLimiter(25, 25)
416+
b.ReportAllocs()
417417
b.StartTimer()
418418
b.RunParallel(func(pb *testing.PB) {
419419
for pb.Next() {
@@ -424,8 +424,8 @@ func BenchmarkConcurrentSetAndCheckHardcore(b *testing.B) {
424424

425425
func BenchmarkConcurrentSetAndCheckStrict(b *testing.B) {
426426
b.StopTimer()
427-
b.ReportAllocs()
428427
limiter := NewDefaultStrictLimiter()
428+
b.ReportAllocs()
429429
b.StartTimer()
430430
b.RunParallel(func(pb *testing.PB) {
431431
for pb.Next() {
@@ -436,8 +436,8 @@ func BenchmarkConcurrentSetAndCheckStrict(b *testing.B) {
436436

437437
func BenchmarkConcurrentPeek(b *testing.B) {
438438
b.StopTimer()
439-
b.ReportAllocs()
440439
limiter := NewDefaultLimiter()
440+
b.ReportAllocs()
441441
b.StartTimer()
442442
b.RunParallel(func(pb *testing.PB) {
443443
for pb.Next() {

0 commit comments

Comments
 (0)