Skip to content

Commit

Permalink
Perf: Increase speed and reduce memory allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
yunginnanet committed Jun 26, 2024
2 parents 2e6d1f7 + fe846fd commit 65c0f29
Show file tree
Hide file tree
Showing 6 changed files with 362 additions and 79 deletions.
18 changes: 16 additions & 2 deletions debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,24 @@ import (
"sync/atomic"
)

const (
msgRateLimitExpired = "ratelimit (expired): %s | last count [%d]"
msgDebugEnabled = "rate5 debug enabled"
msgRateLimitedRst = "ratelimit for %s has been reset"
msgRateLimitedNew = "ratelimit %s (new) "
msgRateLimited = "ratelimit %s: last count %d. time: %s"
msgRateLimitStrict = "%s ratelimit for %s: last count %d. time: %s"
)

func (q *Limiter) debugPrintf(format string, a ...interface{}) {
if atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugDisabled) {
return
}
if len(a) == 2 {
if _, ok := a[1].(*atomic.Int64); ok {
a[1] = a[1].(*atomic.Int64).Load()
}
}
msg := fmt.Sprintf(format, a...)
select {
case q.debugChannel <- msg:
Expand All @@ -21,15 +35,15 @@ func (q *Limiter) debugPrintf(format string, a ...interface{}) {

func (q *Limiter) setDebugEvict() {
q.Patrons.OnEvicted(func(src string, count interface{}) {
q.debugPrintf("ratelimit (expired): %s | last count [%d]", src, count)
q.debugPrintf(msgRateLimitExpired, src, count.(*atomic.Int64).Load())
})
}

func (q *Limiter) SetDebug(on bool) {
switch on {
case true:
if atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled) {
q.debugPrintf("rate5 debug enabled")
q.debugPrintf(msgDebugEnabled)
}
case false:
atomic.CompareAndSwapUint32(&q.debug, DebugEnabled, DebugDisabled)
Expand Down
3 changes: 2 additions & 1 deletion models.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rate5
import (
"fmt"
"sync"
"sync/atomic"

"github.com/patrickmn/go-cache"
)
Expand Down Expand Up @@ -46,7 +47,7 @@ type Limiter struct {
debug uint32
debugChannel chan string
debugLost int64
known map[interface{}]*int64
known map[interface{}]*atomic.Int64
debugMutex *sync.RWMutex
*sync.RWMutex
}
Expand Down
96 changes: 65 additions & 31 deletions ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,29 @@ import (
"github.com/patrickmn/go-cache"
)

const (
strictPrefix = "strict"
hardcorePrefix = "hardcore"
)

var _counters = &sync.Pool{
New: func() interface{} {
i := &atomic.Int64{}
i.Store(0)
return i
},
}

func getCounter() *atomic.Int64 {
got := _counters.Get().(*atomic.Int64)
got.Store(0)
return got
}

func putCounter(i *atomic.Int64) {
_counters.Put(i)
}

/*NewDefaultLimiter returns a ratelimiter with default settings without Strict mode.
* Default window: 25 seconds
* Default burst: 25 requests */
Expand Down Expand Up @@ -70,28 +93,40 @@ func NewHardcoreLimiter(window int, burst int) *Limiter {
return l
}

// ResetItem removes an Identity from the limiter's cache.
// This effectively resets the rate limit for the Identity.
func (q *Limiter) ResetItem(from Identity) {
q.Patrons.Delete(from.UniqueKey())
q.debugPrintf("ratelimit for %s has been reset", from.UniqueKey())
q.debugPrintf(msgRateLimitedRst, from.UniqueKey())
}

func (q *Limiter) onEvict(src string, count interface{}) {
q.debugPrintf(msgRateLimitExpired, src, count)
putCounter(count.(*atomic.Int64))

}

func newLimiter(policy Policy) *Limiter {
window := time.Duration(policy.Window) * time.Second
return &Limiter{
q := &Limiter{
Ruleset: policy,
Patrons: cache.New(window, time.Duration(policy.Window)*time.Second),
known: make(map[interface{}]*int64),
known: make(map[interface{}]*atomic.Int64),
RWMutex: &sync.RWMutex{},
debugMutex: &sync.RWMutex{},
debug: DebugDisabled,
}
q.Patrons.OnEvicted(q.onEvict)
return q
}

func intPtr(i int64) *int64 {
return &i
func intPtr(i int64) *atomic.Int64 {
a := getCounter()
a.Store(i)
return a
}

func (q *Limiter) getHitsPtr(src string) *int64 {
func (q *Limiter) getHitsPtr(src string) *atomic.Int64 {
q.RLock()
if _, ok := q.known[src]; ok {
oldPtr := q.known[src]
Expand All @@ -100,29 +135,29 @@ func (q *Limiter) getHitsPtr(src string) *int64 {
}
q.RUnlock()
q.Lock()
newPtr := intPtr(0)
newPtr := getCounter()
q.known[src] = newPtr
q.Unlock()
return newPtr
}

func (q *Limiter) strictLogic(src string, count int64) {
func (q *Limiter) strictLogic(src string, count *atomic.Int64) {
knownHits := q.getHitsPtr(src)
atomic.AddInt64(knownHits, 1)
knownHits.Add(1)
var extwindow int64
prefix := "hardcore"
prefix := hardcorePrefix
switch {
case q.Ruleset.Hardcore && q.Ruleset.Window > 1:
extwindow = atomic.LoadInt64(knownHits) * q.Ruleset.Window
extwindow = knownHits.Load() * q.Ruleset.Window
case q.Ruleset.Hardcore && q.Ruleset.Window <= 1:
extwindow = atomic.LoadInt64(knownHits) * 2
extwindow = knownHits.Load() * 2
case !q.Ruleset.Hardcore:
prefix = "strict"
extwindow = atomic.LoadInt64(knownHits) + q.Ruleset.Window
prefix = strictPrefix
extwindow = knownHits.Load() + q.Ruleset.Window
}
exttime := time.Duration(extwindow) * time.Second
_ = q.Patrons.Replace(src, count, exttime)
q.debugPrintf("%s ratelimit for %s: last count %d. time: %s", prefix, src, count, exttime)
q.debugPrintf(msgRateLimitStrict, prefix, src, count.Load(), exttime)
}

func (q *Limiter) CheckStringer(from fmt.Stringer) bool {
Expand All @@ -133,33 +168,32 @@ func (q *Limiter) CheckStringer(from fmt.Stringer) bool {
// Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status.
func (q *Limiter) Check(from Identity) (limited bool) {
var count int64
var err error
src := from.UniqueKey()
count, err = q.Patrons.IncrementInt64(src, 1)
if err != nil {
// IncrementInt64 should only error if the value is not an int64, so we can assume it's a new key.
q.debugPrintf("ratelimit %s (new) ", src)
aval, ok := q.Patrons.Get(from.UniqueKey())
switch {
case !ok:
q.debugPrintf(msgRateLimitedNew, from.UniqueKey())
aval = intPtr(1)
// We can't reproduce this throwing an error, we can only assume that the key is new.
_ = q.Patrons.Add(src, int64(1), time.Duration(q.Ruleset.Window)*time.Second)
return false
}
if count < q.Ruleset.Burst {
_ = q.Patrons.Add(from.UniqueKey(), aval, time.Duration(q.Ruleset.Window)*time.Second)
return false
case aval != nil:
count = aval.(*atomic.Int64).Add(1)
if count < q.Ruleset.Burst {
return false
}
}
if q.Ruleset.Strict {
q.strictLogic(src, count)
} else {
q.debugPrintf("ratelimit %s: last count %d. time: %s",
src, count, time.Duration(q.Ruleset.Window)*time.Second)
q.strictLogic(from.UniqueKey(), aval.(*atomic.Int64))
return true
}
q.debugPrintf(msgRateLimited, from.UniqueKey(), count, time.Duration(q.Ruleset.Window)*time.Second)
return true
}

// Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count.
func (q *Limiter) Peek(from Identity) bool {
q.Patrons.DeleteExpired()
if ct, ok := q.Patrons.Get(from.UniqueKey()); ok {
count := ct.(int64)
count := ct.(*atomic.Int64).Load()
if count > q.Ruleset.Burst {
return true
}
Expand Down
33 changes: 17 additions & 16 deletions ratelimiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/binary"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -155,23 +156,23 @@ func Test_NewLimiter(t *testing.T) {

func Test_NewDefaultStrictLimiter(t *testing.T) {
limiter := NewDefaultStrictLimiter()
ctx, cancel := context.WithCancel(context.Background())
go watchDebug(ctx, limiter, t)
// ctx, cancel := context.WithCancel(context.Background())
// go watchDebug(ctx, limiter, t)
time.Sleep(25 * time.Millisecond)
for n := 0; n < 25; n++ {
limiter.Check(dummyTicker)
}
peekCheckLimited(t, limiter, false, false)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, true, false)
cancel()
// cancel()
limiter = nil
}

func Test_NewStrictLimiter(t *testing.T) {
limiter := NewStrictLimiter(5, 1)
ctx, cancel := context.WithCancel(context.Background())
go watchDebug(ctx, limiter, t)
// ctx, cancel := context.WithCancel(context.Background())
// go watchDebug(ctx, limiter, t)
limiter.Check(dummyTicker)
peekCheckLimited(t, limiter, false, false)
limiter.Check(dummyTicker)
Expand All @@ -190,7 +191,7 @@ func Test_NewStrictLimiter(t *testing.T) {
peekCheckLimited(t, limiter, true, false)
time.Sleep(8 * time.Second)
peekCheckLimited(t, limiter, false, false)
cancel()
// cancel()
limiter = nil
}

Expand Down Expand Up @@ -305,7 +306,7 @@ testloop:
if ci, ok = limiter.Patrons.Get(rp.UniqueKey()); !ok {
t.Fatal("randomPatron does not exist in ratelimiter at all!")
}
ct := ci.(int64)
ct := ci.(*atomic.Int64).Load()
if limiter.Peek(rp) && !shouldLimit {
t.Logf("(%d goroutines running)", runtime.NumGoroutine())
// runtime.Breakpoint()
Expand Down Expand Up @@ -349,8 +350,8 @@ func Test_debugChannelOverflow(t *testing.T) {

func BenchmarkCheck(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
limiter := NewDefaultLimiter()
b.ReportAllocs()
b.StartTimer()
for n := 0; n < b.N; n++ {
limiter.Check(dummyTicker)
Expand All @@ -359,8 +360,8 @@ func BenchmarkCheck(b *testing.B) {

func BenchmarkCheckHardcore(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
limiter := NewHardcoreLimiter(25, 25)
b.ReportAllocs()
b.StartTimer()
for n := 0; n < b.N; n++ {
limiter.Check(dummyTicker)
Expand All @@ -369,8 +370,8 @@ func BenchmarkCheckHardcore(b *testing.B) {

func BenchmarkCheckStrict(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
limiter := NewStrictLimiter(25, 25)
b.ReportAllocs()
b.StartTimer()
for n := 0; n < b.N; n++ {
limiter.Check(dummyTicker)
Expand All @@ -379,8 +380,8 @@ func BenchmarkCheckStrict(b *testing.B) {

func BenchmarkCheckStringer(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
limiter := NewDefaultLimiter()
b.ReportAllocs()
b.StartTimer()
for n := 0; n < b.N; n++ {
limiter.CheckStringer(dummyTicker)
Expand All @@ -389,8 +390,8 @@ func BenchmarkCheckStringer(b *testing.B) {

func BenchmarkPeek(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
limiter := NewDefaultLimiter()
b.ReportAllocs()
b.StartTimer()
for n := 0; n < b.N; n++ {
limiter.Peek(dummyTicker)
Expand All @@ -399,8 +400,8 @@ func BenchmarkPeek(b *testing.B) {

func BenchmarkConcurrentCheck(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
limiter := NewDefaultLimiter()
b.ReportAllocs()
b.StartTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
Expand All @@ -411,8 +412,8 @@ func BenchmarkConcurrentCheck(b *testing.B) {

func BenchmarkConcurrentSetAndCheckHardcore(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
limiter := NewHardcoreLimiter(25, 25)
b.ReportAllocs()
b.StartTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
Expand All @@ -423,8 +424,8 @@ func BenchmarkConcurrentSetAndCheckHardcore(b *testing.B) {

func BenchmarkConcurrentSetAndCheckStrict(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
limiter := NewDefaultStrictLimiter()
b.ReportAllocs()
b.StartTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
Expand All @@ -435,8 +436,8 @@ func BenchmarkConcurrentSetAndCheckStrict(b *testing.B) {

func BenchmarkConcurrentPeek(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
limiter := NewDefaultLimiter()
b.ReportAllocs()
b.StartTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
Expand Down
Loading

0 comments on commit 65c0f29

Please sign in to comment.