Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions temporal/internal/driver/redisv8/ratelimiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package redisv8

import (
"context"
"strconv"
"time"

"github.com/TykTechnologies/storage/temporal/temperr"
"github.com/go-redis/redis/v8"
)

// SetRollingWindow updates a sorted set in Redis to represent a rolling time window of values.
func (r *RedisV8) SetRollingWindow(ctx context.Context, now time.Time, keyName string, per int64, valueOverride string, pipeline bool) ([]string, error) {
if keyName == "" {
return []string{}, temperr.KeyEmpty
}
if per <= 0 {
return []string{}, temperr.InvalidPeriod
}

onePeriodAgo := now.Add(time.Duration(-1*per) * time.Second)
expire := time.Duration(per) * time.Second

memberValue := valueOverride
if valueOverride == "-1" {
memberValue = strconv.Itoa(int(now.UnixNano()))
}
element := redis.Z{
Score: float64(now.UnixNano()),
Member: memberValue,
}

var zrange *redis.StringSliceCmd
var err error

exec := r.client.TxPipelined
if pipeline {
exec = r.client.Pipelined
}

pipeFn := func(pipe redis.Pipeliner) error {
// removing elements outside the rolling window.
pipe.ZRemRangeByScore(ctx, keyName, "-inf", strconv.Itoa(int(onePeriodAgo.UnixNano())))
// getting the current range of values within the window.
zrange = pipe.ZRange(ctx, keyName, 0, -1)
// adding the new element and set the expiration time.
pipe.ZAdd(ctx, keyName, &element)
pipe.Expire(ctx, keyName, expire)
return nil
}

_, err = exec(ctx, pipeFn)
if err != nil {
return nil, err
}

return zrange.Result()
}

// GetRollingWindow removes a part of a sorted set in Redis and extracts a timed window of values.
func (r *RedisV8) GetRollingWindow(ctx context.Context, now time.Time, keyName string, per int64, pipeline bool) ([]string, error) {
if keyName == "" {
return []string{}, temperr.KeyEmpty
}
if per <= 0 {
return []string{}, temperr.InvalidPeriod
}

onePeriodAgo := now.Add(time.Duration(-1*per) * time.Second)
period := strconv.FormatInt(onePeriodAgo.UnixNano(), 10)

var zrange *redis.StringSliceCmd
var err error

exec := r.client.TxPipelined
if pipeline {
exec = r.client.Pipelined
}

pipeFn := func(pipe redis.Pipeliner) error {
// removing old elements outside the rolling window
pipe.ZRemRangeByScore(ctx, keyName, "-inf", period)
// retrieving the current range of values
zrange = pipe.ZRange(ctx, keyName, 0, -1)
return nil
}

_, err = exec(ctx, pipeFn)
if err != nil {
return nil, err
}

return zrange.Result()
}
8 changes: 8 additions & 0 deletions temporal/model/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,11 @@ type Message interface {
// - an empty string, returning an error
Payload() (string, error)
}

type RateLimit interface {
// SetRollingWindow sets the rolling window for a key with a per second rate limit
SetRollingWindow(ctx context.Context, now time.Time, keyName string, per int64,
valueOverride string, pipeline bool) ([]string, error)
// GetRollingWindow gets the rolling window for a key with a per second rate limit
GetRollingWindow(ctx context.Context, now time.Time, keyName string, per int64, pipeline bool) ([]string, error)
}
20 changes: 20 additions & 0 deletions temporal/ratelimiter/ratelimiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package ratelimiter

import (
"github.com/TykTechnologies/storage/temporal/internal/driver/redisv8"
"github.com/TykTechnologies/storage/temporal/model"
"github.com/TykTechnologies/storage/temporal/temperr"
)

type RateLimit = model.RateLimit

var _ RateLimit = (*redisv8.RedisV8)(nil)

func NewRateLimit(conn model.Connector) (RateLimit, error) {
switch conn.Type() {
case model.RedisV8Type:
return redisv8.NewRedisV8WithConnection(conn)
default:
return nil, temperr.InvalidHandlerType
}
}
187 changes: 187 additions & 0 deletions temporal/ratelimiter/ratelimiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package ratelimiter

import (
"context"
"testing"
"time"

"github.com/TykTechnologies/storage/temporal/flusher"
"github.com/TykTechnologies/storage/temporal/internal/testutil"
"github.com/TykTechnologies/storage/temporal/model"
"github.com/TykTechnologies/storage/temporal/temperr"
"github.com/stretchr/testify/assert"
)

func TestRedisCluster_SetRollingWindow(t *testing.T) {
connectors := testutil.TestConnectors(t)
defer testutil.CloseConnectors(t, connectors)

tcs := []struct {
name string
keyName string
per int64
valueOverride string
pipeline bool
expectedErr error
expectedLen int
}{
{
name: "valid_rolling_window",
keyName: "key1",
per: 60,
valueOverride: "value1",
pipeline: false,
expectedErr: nil,
expectedLen: 0,
},
{
name: "empty_key_name",
keyName: "",
per: 60,
valueOverride: "value2",
pipeline: false,
expectedErr: temperr.KeyEmpty,
expectedLen: 0,
},
{
name: "negative_period",
keyName: "key2",
per: -10,
valueOverride: "value3",
pipeline: false,
expectedErr: temperr.InvalidPeriod,
expectedLen: 0,
},
{
name: "pipeline_enabled",
keyName: "key_pipeline",
per: 60,
valueOverride: "pipeline_value",
pipeline: true,
expectedErr: nil,
expectedLen: 0,
},
{
name: "valueOverride",
keyName: "key_value_override",
per: 60,
valueOverride: "-1",
pipeline: false,
expectedErr: nil,
expectedLen: 0,
},
}

for _, connector := range connectors {
for _, tc := range tcs {
t.Run(connector.Type()+"_"+tc.name, func(t *testing.T) {
now := time.Now()
ctx := context.Background()

rateLimiter, err := NewRateLimit(connector)
assert.Nil(t, err)

flusher, err := flusher.NewFlusher(connector)
assert.Nil(t, err)
defer assert.Nil(t, flusher.FlushAll(ctx))

result, err := rateLimiter.SetRollingWindow(ctx, now, tc.keyName, tc.per, tc.valueOverride, tc.pipeline)

assert.Equal(t, tc.expectedErr, err)

if err == nil {
assert.Equal(t, tc.expectedLen, len(result))
// Executing SetRollingWindow again should return expectedLen + 1 if err == nil
result, err = rateLimiter.SetRollingWindow(ctx, now, tc.keyName, tc.per, tc.valueOverride, tc.pipeline)
assert.NoError(t, err)
assert.Equal(t, tc.expectedLen+1, len(result))
}
})
}
}
}

func TestRedisCluster_GetRollingWindow(t *testing.T) {
connectors := testutil.TestConnectors(t)
defer testutil.CloseConnectors(t, connectors)

tcs := []struct {
name string
keyName string
per int64
pipeline bool
expectedErr error
expectedLen int
preTest func(ctx context.Context, rateLimiter model.RateLimit)
}{
{
name: "empty_sorted_set",
keyName: "key_empty",
per: 60,
pipeline: false,
expectedErr: nil,
expectedLen: 0,
},
{
name: "non_empty_sorted_set",
keyName: "key_non_empty",
per: 60,
pipeline: false,
expectedErr: nil,
expectedLen: 2,
preTest: func(ctx context.Context, rateLimiter model.RateLimit) {
now := time.Now()
_, err := rateLimiter.SetRollingWindow(ctx, now, "key_non_empty", 60, "value1", false)
assert.Nil(t, err)
_, err = rateLimiter.SetRollingWindow(ctx, now, "key_non_empty", 60, "value2", false)
assert.Nil(t, err)
},
},
{
name: "pipeline_enabled",
keyName: "key_pipeline",
per: 60,
pipeline: true,
expectedErr: nil,
expectedLen: 0,
},
{
name: "negative_period",
keyName: "key_negative_period",
per: -10,
pipeline: false,
expectedErr: temperr.InvalidPeriod,
expectedLen: 0,
},
{
name: "empty_key_name",
keyName: "",
per: 60,
pipeline: false,
expectedErr: temperr.KeyEmpty,
expectedLen: 0,
},
}

for _, connector := range connectors {
for _, tc := range tcs {
t.Run(connector.Type()+"_"+tc.name, func(t *testing.T) {
ctx := context.Background()

rateLimiter, err := NewRateLimit(connector)
assert.Nil(t, err)

if tc.preTest != nil {
tc.preTest(ctx, rateLimiter)
}

result, err := rateLimiter.GetRollingWindow(ctx, time.Now(), tc.keyName, tc.per, tc.pipeline)

assert.Equal(t, tc.expectedErr, err)
if err == nil {
assert.Equal(t, tc.expectedLen, len(result))
}
})
}
}
}
1 change: 1 addition & 0 deletions temporal/temperr/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var (

// Redis related errors
InvalidRedisClient = errors.New("invalid redis client")
InvalidPeriod = errors.New("invalid period specified")

// TLS related errors
// TLS related errors
Expand Down