Skip to content

Commit

Permalink
added RedisOptions and InMemoryOptions, added Skip(c *gin.Context) bo…
Browse files Browse the repository at this point in the history
…ol to the Store interface, and changed compatibility to 1.17.
  • Loading branch information
Raven0213 committed Jul 27, 2022
1 parent 1366354 commit d7398b8
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 35 deletions.
30 changes: 23 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ func errorHandler(c *gin.Context, remaining time.Duration) {
func main() {
server := gin.Default()
// This makes it so each ip can only make 5 requests per second
store := ratelimit.RedisStore(time.Second, 5, redis.NewClient(&redis.Options{
Addr: "localhost:7680",
}), false)
store := ratelimit.RedisStore(&ratelimit.RedisOptions{
RedisClient: redis.NewClient(&redis.Options{
Addr: "localhost:7680",
}),
Rate: time.Second,
Limit: 5,
})
mw := ratelimit.RateLimiter(keyFunc, errorHandler, store)
server.GET("/", mw, func(c *gin.Context) {
c.String(200, "Hello World")
Expand Down Expand Up @@ -75,7 +79,10 @@ func errorHandler(c *gin.Context, remaining time.Duration) {
func main() {
server := gin.Default()
// This makes it so each ip can only make 5 requests per second
store := ratelimit.InMemoryStore(time.Second, 5)
store := ratelimit.InMemoryStore(&ratelimit.InMemoryOptions{
Rate: time.Second,
Limit: 5,
})
mw := ratelimit.RateLimiter(keyFunc, errorHandler, store)
server.GET("/", mw, func(c *gin.Context) {
c.String(200, "Hello World")
Expand All @@ -91,18 +98,27 @@ Custom Store Example
```go
package main

import "time"
import (
"github.com/gin-gonic/gin"
"time"
)

type CustomStore struct {
}

// Your store must have a method called Limit that takes a key and returns a bool
// Your store must have a method called Limit that takes a key and returns a bool, time.Duration
func (s *CustomStore) Limit(key string) (bool, time.Duration) {
// Do your rate limit logic, and return true if the user went over the rate limit, otherwise return false
// Return the amount of time the client needs to wait to make a new request
if UserWentOverLimit {
return true
return true, remaining
}
return false, remaining
}

// Your store must have a method called Skip that takes a *gin.Context and returns a bool
func (s *CustomStore) Skip(c *gin.Context) bool {
// return true if you dont want this request to count toward the users rate limit
return false
}
```
28 changes: 28 additions & 0 deletions gin_rate_limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ratelimit

import (
"github.com/gin-gonic/gin"
"time"
)

type Store interface {
Limit(key string) (bool, time.Duration)
Skip(c *gin.Context) bool
}

func RateLimiter(keyFunc func(c *gin.Context) string, errorHandler func(c *gin.Context, remaining time.Duration), s Store) func(ctx *gin.Context) {
return func(c *gin.Context) {
if s.Skip(c) {
c.Next()
return
}
key := keyFunc(c)
limited, remaining := s.Limit(key)
if limited {
errorHandler(c, remaining)
c.Abort()
} else {
c.Next()
}
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/JGLTechnologies/gin-rate-limit

go 1.18
go 1.17

require (
github.com/gin-gonic/gin v1.8.1
Expand Down
42 changes: 20 additions & 22 deletions in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import (

type user struct {
ts int64
tokens int
tokens uint
}

func clearInBackground(data *sync.Map, rate int64) {
for {
data.Range(func(k, v any) bool {
data.Range(func(k, v interface{}) bool {
if v.(user).ts+rate <= time.Now().Unix() {
data.Delete(k)
}
Expand All @@ -23,13 +23,14 @@ func clearInBackground(data *sync.Map, rate int64) {
}
}

type InMemoryStoreType struct {
type inMemoryStoreType struct {
rate int64
limit int
limit uint
data *sync.Map
skip func(c *gin.Context) bool
}

func (s *InMemoryStoreType) Limit(key string) (bool, time.Duration) {
func (s *inMemoryStoreType) Limit(key string) (bool, time.Duration) {
var u user
m, ok := s.data.Load(key)
if !ok {
Expand All @@ -50,26 +51,23 @@ func (s *InMemoryStoreType) Limit(key string) (bool, time.Duration) {
return false, time.Duration(0)
}

type store interface {
Limit(key string) (bool, time.Duration)
func (s *inMemoryStoreType) Skip(c *gin.Context) bool {
if s.skip != nil {
return s.skip(c)
} else {
return false
}
}

type InMemoryOptions struct {
Rate time.Duration
Limit uint
Skip func(c *gin.Context) bool
}

func InMemoryStore(rate time.Duration, limit int) *InMemoryStoreType {
func InMemoryStore(options *InMemoryOptions) Store {
data := &sync.Map{}
store := InMemoryStoreType{int64(rate.Seconds()), limit, data}
store := inMemoryStoreType{int64(options.Rate.Seconds()), options.Limit, data, options.Skip}
go clearInBackground(data, store.rate)
return &store
}

func RateLimiter(keyFunc func(c *gin.Context) string, errorHandler func(c *gin.Context, remaining time.Duration), s store) func(ctx *gin.Context) {
return func(c *gin.Context) {
key := keyFunc(c)
limited, remaining := s.Limit(key)
if limited {
errorHandler(c, remaining)
c.Abort()
} else {
c.Next()
}
}
}
35 changes: 30 additions & 5 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ package ratelimit

import (
"context"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"time"
)

type RedisStoreType struct {
type redisStoreType struct {
rate int64
limit int
limit uint
client redis.UniversalClient
ctx context.Context
panicOnErr bool
skip func(c *gin.Context) bool
}

func (s *RedisStoreType) Limit(key string) (bool, time.Duration) {
func (s *redisStoreType) Limit(key string) (bool, time.Duration) {
p := s.client.Pipeline()
defer p.Close()
cmds, _ := s.client.Pipelined(s.ctx, func(pipeliner redis.Pipeliner) error {
Expand Down Expand Up @@ -61,6 +63,29 @@ func (s *RedisStoreType) Limit(key string) (bool, time.Duration) {
return false, time.Duration(0)
}

func RedisStore(rate time.Duration, limit int, redisClient redis.UniversalClient, panicOnErr bool) *RedisStoreType {
return &RedisStoreType{client: redisClient, rate: int64(rate.Seconds()), limit: limit, ctx: context.TODO(), panicOnErr: panicOnErr}
func (s *redisStoreType) Skip(c *gin.Context) bool {
if s.skip != nil {
return s.skip(c)
} else {
return false
}
}

type RedisOptions struct {
Rate time.Duration
Limit uint
RedisClient redis.UniversalClient
Skip func(c *gin.Context) bool
PanicOnErr bool
}

func RedisStore(options *RedisOptions) Store {
return &redisStoreType{
client: options.RedisClient,
rate: int64(options.Rate.Seconds()),
limit: options.Limit,
ctx: context.TODO(),
panicOnErr: options.PanicOnErr,
skip: options.Skip,
}
}

0 comments on commit d7398b8

Please sign in to comment.