Skip to content

Commit

Permalink
Added BeforeResponse to Options, Changed Store.Limit() to take in a k…
Browse files Browse the repository at this point in the history
…ey, *gin.Context and return ratelimit.Info. ErrorHandler now takes in Info instead of remaining time. Moved Skip from Options into The store options.

Check README.md to see the updates.
  • Loading branch information
Nebulizer1213 committed Aug 5, 2022
1 parent fa48046 commit deed734
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 35 deletions.
28 changes: 17 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func keyFunc(c *gin.Context) string {
return c.ClientIP()
}

func errorHandler(c *gin.Context, remaining time.Duration) {
c.String(429, "Too many requests. Try again in "+remaining.String())
func errorHandler(c *gin.Context, info ratelimit.Info) {
c.String(429, "Too many requests. Try again in "+time.Until(info.ResetTime).String())
}

func main() {
Expand Down Expand Up @@ -75,8 +75,8 @@ func keyFunc(c *gin.Context) string {
return c.ClientIP()
}

func errorHandler(c *gin.Context, remaining time.Duration) {
c.String(429, "Too many requests. Try again in "+remaining.String())
func errorHandler(c *gin.Context, info ratelimit.Info) {
c.String(429, "Too many requests. Try again in "+time.Until(info.ResetTime).String())
}

func main() {
Expand Down Expand Up @@ -106,20 +106,26 @@ Custom Store Example
package main

import (
"github.com/JGLTechnologies/gin-rate-limit"
"github.com/gin-gonic/gin"
"time"
)

type CustomStore struct {
}

// Your store must have a method called Limit that takes a key and returns a bool, time.Duration
func (s *CustomStore) Limit(key string) (bool, time.Duration) {
// Do your rate limit logic, and return true if the user went over the rate limit, otherwise return false
// Return the amount of time the client needs to wait to make a new request
// Your store must have a method called Limit that takes a key, *gin.Context and returns ratelimit.Info
func (s *CustomStore) Limit(key string, c *gin.Context) Info {
if UserWentOverLimit {
return true, remaining
return Info{
RateLimited: true,
ResetTime: reset,
RemainingHits: 0,
}
}
return Info{
RateLimited: false,
ResetTime: reset,
RemainingHits: remaining,
}
return false, remaining
}
```
39 changes: 26 additions & 13 deletions gin_rate_limit.go
Original file line number Diff line number Diff line change
@@ -1,44 +1,57 @@
package ratelimit

import (
"fmt"
"github.com/gin-gonic/gin"
"time"
)

type Info struct {
RateLimited bool
ResetTime time.Time
RemainingHits uint
}

type Store interface {
// Limit takes in a key and should return whether that key is allowed to make another request
Limit(key string) (bool, time.Duration)
// Limit takes in a key and *gin.Context and should return whether that key is allowed to make another request
Limit(key string, c *gin.Context) Info
}

type Options struct {
ErrorHandler func(*gin.Context, time.Duration)
ErrorHandler func(*gin.Context, Info)
KeyFunc func(*gin.Context) string
// a function that returns true if the request should not count toward the rate limit
Skip func(*gin.Context) bool
// a function that lets you check the rate limiting info and modify the response
BeforeResponse func(c *gin.Context, info Info)
}

// RateLimiter is a function to get gin.HandlerFunc
func RateLimiter(s Store, options *Options) gin.HandlerFunc {
if options.ErrorHandler == nil {
options.ErrorHandler = func(c *gin.Context, remaining time.Duration) {
c.Header("X-Rate-Limit-Reset", remaining.String())
options.ErrorHandler = func(c *gin.Context, info Info) {
c.Header("X-Rate-Limit-Reset", fmt.Sprintf("%.2f", time.Until(info.ResetTime).Seconds()))
c.String(429, "Too many requests")
}
}
if options.BeforeResponse == nil {
options.BeforeResponse = func(c *gin.Context, info Info) {
c.Header("X-Rate-Limit-Remaining", fmt.Sprintf("%v", info.RemainingHits))
c.Header("X-Rate-Limit-Reset", fmt.Sprintf("%.2f", time.Until(info.ResetTime).Seconds()))
}
}
if options.KeyFunc == nil {
options.KeyFunc = func(c *gin.Context) string {
return c.ClientIP() + c.FullPath()
}
}
return func(c *gin.Context) {
if options.Skip != nil && options.Skip(c) {
c.Next()
key := options.KeyFunc(c)
info := s.Limit(key, c)
options.BeforeResponse(c, info)
if c.IsAborted() {
return
}
key := options.KeyFunc(c)
limited, remaining := s.Limit(key)
if limited {
options.ErrorHandler(c, remaining)
if info.RateLimited {
options.ErrorHandler(c, info)
c.Abort()
} else {
c.Next()
Expand Down
28 changes: 23 additions & 5 deletions in_memory.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ratelimit

import (
"github.com/gin-gonic/gin"
"sync"
"time"
)
Expand All @@ -26,9 +27,10 @@ type inMemoryStoreType struct {
rate int64
limit uint
data *sync.Map
skip func(ctx *gin.Context) bool
}

func (s *inMemoryStoreType) Limit(key string) (bool, time.Duration) {
func (s *inMemoryStoreType) Limit(key string, c *gin.Context) Info {
var u user
m, ok := s.data.Load(key)
if !ok {
Expand All @@ -39,26 +41,42 @@ func (s *inMemoryStoreType) Limit(key string) (bool, time.Duration) {
if u.ts+s.rate <= time.Now().Unix() {
u.tokens = s.limit
}
remaining := time.Duration((s.rate - (time.Now().Unix() - u.ts)) * time.Second.Nanoseconds())
if s.skip != nil && s.skip(c) {
return Info{
RateLimited: false,
ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - u.ts)) * time.Second.Nanoseconds())),
RemainingHits: u.tokens,
}
}
if u.tokens <= 0 {
return true, remaining
return Info{
RateLimited: true,
ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - u.ts)) * time.Second.Nanoseconds())),
RemainingHits: 0,
}
}
u.tokens--
u.ts = time.Now().Unix()
s.data.Store(key, u)
return false, time.Duration(0)
return Info{
RateLimited: false,
ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - u.ts)) * time.Second.Nanoseconds())),
RemainingHits: u.tokens,
}
}

type InMemoryOptions struct {
// the user can make Limit amount of requests every Rate
Rate time.Duration
// the amount of requests that can be made every Rate
Limit uint
// a function that returns true if the request should not count toward the rate limit
Skip func(*gin.Context) bool
}

func InMemoryStore(options *InMemoryOptions) Store {
data := &sync.Map{}
store := inMemoryStoreType{int64(options.Rate.Seconds()), options.Limit, data}
store := inMemoryStoreType{int64(options.Rate.Seconds()), options.Limit, data, options.Skip}
go clearInBackground(data, store.rate)
return &store
}
39 changes: 33 additions & 6 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type redisStoreType struct {
skip func(c *gin.Context) bool
}

func (s *redisStoreType) Limit(key string) (bool, time.Duration) {
func (s *redisStoreType) Limit(key string, c *gin.Context) Info {
p := s.client.Pipeline()
defer p.Close()
cmds, _ := s.client.Pipelined(s.ctx, func(pipeliner redis.Pipeliner) error {
Expand All @@ -36,18 +36,34 @@ func (s *redisStoreType) Limit(key string) (bool, time.Duration) {
hits = 0
p.Set(s.ctx, key+"hits", hits, time.Duration(0))
}
remaining := time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())
if s.skip != nil && s.skip(c) {
return Info{
RateLimited: false,
ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())),
RemainingHits: s.limit - uint(hits),
}
}
if hits >= int64(s.limit) {
_, err = p.Exec(s.ctx)
if err != nil {
if s.panicOnErr {
panic(err)
} else {
return false, time.Duration(0)
return Info{
RateLimited: false,
ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())),
RemainingHits: 0,
}
}
}
return true, remaining
return Info{
RateLimited: true,
ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())),
RemainingHits: 0,
}
}
ts = time.Now().Unix()
hits++
p.Incr(s.ctx, key+"hits")
p.Set(s.ctx, key+"ts", time.Now().Unix(), time.Duration(0))
p.Expire(s.ctx, key+"hits", time.Duration(int64(time.Second)*s.rate*2))
Expand All @@ -57,10 +73,18 @@ func (s *redisStoreType) Limit(key string) (bool, time.Duration) {
if s.panicOnErr {
panic(err)
} else {
return false, time.Duration(0)
return Info{
RateLimited: false,
ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())),
RemainingHits: s.limit - uint(hits),
}
}
}
return false, time.Duration(0)
return Info{
RateLimited: false,
ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())),
RemainingHits: s.limit - uint(hits),
}
}

type RedisOptions struct {
Expand All @@ -71,6 +95,8 @@ type RedisOptions struct {
RedisClient redis.UniversalClient
// should gin-rate-limit panic when there is an error with redis
PanicOnErr bool
// a function that returns true if the request should not count toward the rate limit
Skip func(*gin.Context) bool
}

func RedisStore(options *RedisOptions) Store {
Expand All @@ -80,5 +106,6 @@ func RedisStore(options *RedisOptions) Store {
limit: options.Limit,
ctx: context.TODO(),
panicOnErr: options.PanicOnErr,
skip: options.Skip,
}
}

0 comments on commit deed734

Please sign in to comment.