diff --git a/GinRateLimit.go b/GinRateLimit.go index 3824ed3..0c54bd6 100644 --- a/GinRateLimit.go +++ b/GinRateLimit.go @@ -11,15 +11,14 @@ type user struct { tokens int } -func clearInBackground(data map[string]*user, rate int64, mutex *sync.Mutex) { +func clearInBackground(data *sync.Map, rate int64) { for { - mutex.Lock() - for k, v := range data { - if v.ts+rate <= time.Now().Unix() { - delete(data, k) + data.Range(func(k, v any) bool { + if v.(user).ts+rate <= time.Now().Unix() { + data.Delete(k) } - } - mutex.Unlock() + return true + }) time.Sleep(time.Minute) } } @@ -27,18 +26,16 @@ func clearInBackground(data map[string]*user, rate int64, mutex *sync.Mutex) { type InMemoryStoreType struct { rate int64 limit int - data map[string]*user - mutex *sync.Mutex + data *sync.Map } func (s *InMemoryStoreType) Limit(key string) (bool, time.Duration) { - s.mutex.Lock() - defer s.mutex.Unlock() - _, ok := s.data[key] + _, ok := s.data.Load(key) if !ok { - s.data[key] = &user{time.Now().Unix(), s.limit} + s.data.Store(key, user{time.Now().Unix(), s.limit}) } - u := s.data[key] + m, _ := s.data.Load(key) + u := m.(user) if u.ts+s.rate <= time.Now().Unix() { u.tokens = s.limit } @@ -48,7 +45,7 @@ func (s *InMemoryStoreType) Limit(key string) (bool, time.Duration) { } u.tokens-- u.ts = time.Now().Unix() - s.data[key] = u + s.data.Store(key, u) return false, time.Duration(0) } @@ -57,10 +54,9 @@ type store interface { } func InMemoryStore(rate time.Duration, limit int) *InMemoryStoreType { - mutex := &sync.Mutex{} - data := map[string]*user{} - store := InMemoryStoreType{int64(rate.Seconds()), limit, data, mutex} - go clearInBackground(data, store.rate, mutex) + data := &sync.Map{} + store := InMemoryStoreType{int64(rate.Seconds()), limit, data} + go clearInBackground(data, store.rate) return &store } diff --git a/redis.go b/redis.go index 4ab9af2..3b5b107 100644 --- a/redis.go +++ b/redis.go @@ -31,7 +31,8 @@ func (s *RedisStoreType) Limit(key string) (bool, time.Duration) { hits = 0 } if ts+s.rate <= time.Now().Unix() { - p.Set(s.ctx, key+"hits", 0, time.Duration(0)) + hits = 0 + p.Set(s.ctx, key+"hits", hits, time.Duration(0)) } remaining := time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds()) if hits >= int64(s.limit) {