diff --git a/temporal/internal/driver/redisv8/ratelimiter.go b/temporal/internal/driver/redisv8/ratelimiter.go new file mode 100644 index 00000000..eac3873f --- /dev/null +++ b/temporal/internal/driver/redisv8/ratelimiter.go @@ -0,0 +1,94 @@ +package redisv8 + +import ( + "context" + "strconv" + "time" + + "github.com/TykTechnologies/storage/temporal/temperr" + "github.com/go-redis/redis/v8" +) + +// SetRollingWindow updates a sorted set in Redis to represent a rolling time window of values. +func (r *RedisV8) SetRollingWindow(ctx context.Context, now time.Time, keyName string, per int64, valueOverride string, pipeline bool) ([]string, error) { + if keyName == "" { + return []string{}, temperr.KeyEmpty + } + if per <= 0 { + return []string{}, temperr.InvalidPeriod + } + + onePeriodAgo := now.Add(time.Duration(-1*per) * time.Second) + expire := time.Duration(per) * time.Second + + memberValue := valueOverride + if valueOverride == "-1" { + memberValue = strconv.Itoa(int(now.UnixNano())) + } + element := redis.Z{ + Score: float64(now.UnixNano()), + Member: memberValue, + } + + var zrange *redis.StringSliceCmd + var err error + + exec := r.client.TxPipelined + if pipeline { + exec = r.client.Pipelined + } + + pipeFn := func(pipe redis.Pipeliner) error { + // removing elements outside the rolling window. + pipe.ZRemRangeByScore(ctx, keyName, "-inf", strconv.Itoa(int(onePeriodAgo.UnixNano()))) + // getting the current range of values within the window. + zrange = pipe.ZRange(ctx, keyName, 0, -1) + // adding the new element and set the expiration time. + pipe.ZAdd(ctx, keyName, &element) + pipe.Expire(ctx, keyName, expire) + return nil + } + + _, err = exec(ctx, pipeFn) + if err != nil { + return nil, err + } + + return zrange.Result() +} + +// GetRollingWindow removes a part of a sorted set in Redis and extracts a timed window of values. +func (r *RedisV8) GetRollingWindow(ctx context.Context, now time.Time, keyName string, per int64, pipeline bool) ([]string, error) { + if keyName == "" { + return []string{}, temperr.KeyEmpty + } + if per <= 0 { + return []string{}, temperr.InvalidPeriod + } + + onePeriodAgo := now.Add(time.Duration(-1*per) * time.Second) + period := strconv.FormatInt(onePeriodAgo.UnixNano(), 10) + + var zrange *redis.StringSliceCmd + var err error + + exec := r.client.TxPipelined + if pipeline { + exec = r.client.Pipelined + } + + pipeFn := func(pipe redis.Pipeliner) error { + // removing old elements outside the rolling window + pipe.ZRemRangeByScore(ctx, keyName, "-inf", period) + // retrieving the current range of values + zrange = pipe.ZRange(ctx, keyName, 0, -1) + return nil + } + + _, err = exec(ctx, pipeFn) + if err != nil { + return nil, err + } + + return zrange.Result() +} diff --git a/temporal/model/types.go b/temporal/model/types.go index 1a0d38b3..5e2b3a7c 100644 --- a/temporal/model/types.go +++ b/temporal/model/types.go @@ -160,3 +160,11 @@ type Message interface { // - an empty string, returning an error Payload() (string, error) } + +type RateLimit interface { + // SetRollingWindow sets the rolling window for a key with a per second rate limit + SetRollingWindow(ctx context.Context, now time.Time, keyName string, per int64, + valueOverride string, pipeline bool) ([]string, error) + // GetRollingWindow gets the rolling window for a key with a per second rate limit + GetRollingWindow(ctx context.Context, now time.Time, keyName string, per int64, pipeline bool) ([]string, error) +} diff --git a/temporal/ratelimiter/ratelimiter.go b/temporal/ratelimiter/ratelimiter.go new file mode 100644 index 00000000..99a3d53c --- /dev/null +++ b/temporal/ratelimiter/ratelimiter.go @@ -0,0 +1,20 @@ +package ratelimiter + +import ( + "github.com/TykTechnologies/storage/temporal/internal/driver/redisv8" + "github.com/TykTechnologies/storage/temporal/model" + "github.com/TykTechnologies/storage/temporal/temperr" +) + +type RateLimit = model.RateLimit + +var _ RateLimit = (*redisv8.RedisV8)(nil) + +func NewRateLimit(conn model.Connector) (RateLimit, error) { + switch conn.Type() { + case model.RedisV8Type: + return redisv8.NewRedisV8WithConnection(conn) + default: + return nil, temperr.InvalidHandlerType + } +} diff --git a/temporal/ratelimiter/ratelimiter_test.go b/temporal/ratelimiter/ratelimiter_test.go new file mode 100644 index 00000000..88ff1e17 --- /dev/null +++ b/temporal/ratelimiter/ratelimiter_test.go @@ -0,0 +1,187 @@ +package ratelimiter + +import ( + "context" + "testing" + "time" + + "github.com/TykTechnologies/storage/temporal/flusher" + "github.com/TykTechnologies/storage/temporal/internal/testutil" + "github.com/TykTechnologies/storage/temporal/model" + "github.com/TykTechnologies/storage/temporal/temperr" + "github.com/stretchr/testify/assert" +) + +func TestRedisCluster_SetRollingWindow(t *testing.T) { + connectors := testutil.TestConnectors(t) + defer testutil.CloseConnectors(t, connectors) + + tcs := []struct { + name string + keyName string + per int64 + valueOverride string + pipeline bool + expectedErr error + expectedLen int + }{ + { + name: "valid_rolling_window", + keyName: "key1", + per: 60, + valueOverride: "value1", + pipeline: false, + expectedErr: nil, + expectedLen: 0, + }, + { + name: "empty_key_name", + keyName: "", + per: 60, + valueOverride: "value2", + pipeline: false, + expectedErr: temperr.KeyEmpty, + expectedLen: 0, + }, + { + name: "negative_period", + keyName: "key2", + per: -10, + valueOverride: "value3", + pipeline: false, + expectedErr: temperr.InvalidPeriod, + expectedLen: 0, + }, + { + name: "pipeline_enabled", + keyName: "key_pipeline", + per: 60, + valueOverride: "pipeline_value", + pipeline: true, + expectedErr: nil, + expectedLen: 0, + }, + { + name: "valueOverride", + keyName: "key_value_override", + per: 60, + valueOverride: "-1", + pipeline: false, + expectedErr: nil, + expectedLen: 0, + }, + } + + for _, connector := range connectors { + for _, tc := range tcs { + t.Run(connector.Type()+"_"+tc.name, func(t *testing.T) { + now := time.Now() + ctx := context.Background() + + rateLimiter, err := NewRateLimit(connector) + assert.Nil(t, err) + + flusher, err := flusher.NewFlusher(connector) + assert.Nil(t, err) + defer assert.Nil(t, flusher.FlushAll(ctx)) + + result, err := rateLimiter.SetRollingWindow(ctx, now, tc.keyName, tc.per, tc.valueOverride, tc.pipeline) + + assert.Equal(t, tc.expectedErr, err) + + if err == nil { + assert.Equal(t, tc.expectedLen, len(result)) + // Executing SetRollingWindow again should return expectedLen + 1 if err == nil + result, err = rateLimiter.SetRollingWindow(ctx, now, tc.keyName, tc.per, tc.valueOverride, tc.pipeline) + assert.NoError(t, err) + assert.Equal(t, tc.expectedLen+1, len(result)) + } + }) + } + } +} + +func TestRedisCluster_GetRollingWindow(t *testing.T) { + connectors := testutil.TestConnectors(t) + defer testutil.CloseConnectors(t, connectors) + + tcs := []struct { + name string + keyName string + per int64 + pipeline bool + expectedErr error + expectedLen int + preTest func(ctx context.Context, rateLimiter model.RateLimit) + }{ + { + name: "empty_sorted_set", + keyName: "key_empty", + per: 60, + pipeline: false, + expectedErr: nil, + expectedLen: 0, + }, + { + name: "non_empty_sorted_set", + keyName: "key_non_empty", + per: 60, + pipeline: false, + expectedErr: nil, + expectedLen: 2, + preTest: func(ctx context.Context, rateLimiter model.RateLimit) { + now := time.Now() + _, err := rateLimiter.SetRollingWindow(ctx, now, "key_non_empty", 60, "value1", false) + assert.Nil(t, err) + _, err = rateLimiter.SetRollingWindow(ctx, now, "key_non_empty", 60, "value2", false) + assert.Nil(t, err) + }, + }, + { + name: "pipeline_enabled", + keyName: "key_pipeline", + per: 60, + pipeline: true, + expectedErr: nil, + expectedLen: 0, + }, + { + name: "negative_period", + keyName: "key_negative_period", + per: -10, + pipeline: false, + expectedErr: temperr.InvalidPeriod, + expectedLen: 0, + }, + { + name: "empty_key_name", + keyName: "", + per: 60, + pipeline: false, + expectedErr: temperr.KeyEmpty, + expectedLen: 0, + }, + } + + for _, connector := range connectors { + for _, tc := range tcs { + t.Run(connector.Type()+"_"+tc.name, func(t *testing.T) { + ctx := context.Background() + + rateLimiter, err := NewRateLimit(connector) + assert.Nil(t, err) + + if tc.preTest != nil { + tc.preTest(ctx, rateLimiter) + } + + result, err := rateLimiter.GetRollingWindow(ctx, time.Now(), tc.keyName, tc.per, tc.pipeline) + + assert.Equal(t, tc.expectedErr, err) + if err == nil { + assert.Equal(t, tc.expectedLen, len(result)) + } + }) + } + } +} diff --git a/temporal/temperr/errors.go b/temporal/temperr/errors.go index 91fb02e4..c685c3cb 100644 --- a/temporal/temperr/errors.go +++ b/temporal/temperr/errors.go @@ -17,6 +17,7 @@ var ( // Redis related errors InvalidRedisClient = errors.New("invalid redis client") + InvalidPeriod = errors.New("invalid period specified") // TLS related errors // TLS related errors