Skip to content

Commit

Permalink
[TT-11091] Fixing Keys method to fetch all the keys when scanning (#110)
Browse files Browse the repository at this point in the history
* Fixing Keys method to fetch all the keys when scanning

* renaming old function to avoid confussion
  • Loading branch information
mativm02 authored Feb 6, 2024
1 parent d2a325e commit 24ad76c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 26 deletions.
41 changes: 27 additions & 14 deletions temporal/internal/driver/redisv9/keyvalue.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func (r *RedisV9) Keys(ctx context.Context, pattern string) ([]string, error) {
switch client := r.client.(type) {
case *redis.ClusterClient:
err := client.ForEachMaster(ctx, func(ctx context.Context, client *redis.Client) error {
keys, _, err := fetchKeys(ctx, client, pattern, 0, 0)
keys, err := fetchAllKeys(ctx, client, pattern)
if err != nil {
if firstError == nil {
firstError = err
Expand All @@ -269,7 +269,7 @@ func (r *RedisV9) Keys(ctx context.Context, pattern string) ([]string, error) {
}

case *redis.Client:
keys, _, err := fetchKeys(ctx, client, pattern, 0, 0)
keys, err := fetchAllKeys(ctx, client, pattern)
if err != nil {
if errors.Is(err, redis.ErrClosed) {
return nil, temperr.ClosedConnection
Expand All @@ -279,7 +279,6 @@ func (r *RedisV9) Keys(ctx context.Context, pattern string) ([]string, error) {
}

sessions = keys

default:
return nil, temperr.InvalidRedisClient
}
Expand Down Expand Up @@ -412,7 +411,7 @@ func (r *RedisV9) GetKeysWithOpts(ctx context.Context,
return nil
}

localKeys, fkCursor, err := fetchKeys(ctx, client, searchStr, cursor[client.String()], count)
localKeys, fkCursor, err := fetchKeysWithCursor(ctx, client, searchStr, cursor[client.String()], count)
if err != nil {
return err
}
Expand All @@ -437,7 +436,7 @@ func (r *RedisV9) GetKeysWithOpts(ctx context.Context,
}

case *redis.Client:
localKeys, fkCursor, err := fetchKeys(ctx, client, searchStr, cursor[client.String()], int64(count))
localKeys, fkCursor, err := fetchKeysWithCursor(ctx, client, searchStr, cursor[client.String()], int64(count))
if err != nil {
if errors.Is(err, redis.ErrClosed) {
return localKeys, cursor, continueScan, temperr.ClosedConnection
Expand All @@ -460,7 +459,20 @@ func (r *RedisV9) GetKeysWithOpts(ctx context.Context,
return keys, cursor, continueScan, nil
}

func fetchKeys(ctx context.Context,
func (r *RedisV9) SetIfNotExist(ctx context.Context, key, value string, expiration time.Duration) (bool, error) {
if key == "" {
return false, temperr.KeyEmpty
}

res := r.client.SetNX(ctx, key, value, expiration)
if res.Err() != nil {
return false, res.Err()
}

return res.Val(), nil
}

func fetchKeysWithCursor(ctx context.Context,
client redis.UniversalClient,
pattern string,
cursor uint64,
Expand All @@ -478,15 +490,16 @@ func fetchKeys(ctx context.Context,
return keys, cursor, nil
}

func (r *RedisV9) SetIfNotExist(ctx context.Context, key, value string, expiration time.Duration) (bool, error) {
if key == "" {
return false, temperr.KeyEmpty
}
func fetchAllKeys(ctx context.Context,
client redis.UniversalClient,
pattern string,
) ([]string, error) {
iter := client.Scan(ctx, 0, pattern, 0).Iterator()
var keys []string

res := r.client.SetNX(ctx, key, value, expiration)
if res.Err() != nil {
return false, res.Err()
for iter.Next(ctx) {
keys = append(keys, iter.Val())
}

return res.Val(), nil
return keys, iter.Err()
}
21 changes: 9 additions & 12 deletions temporal/keyvalue/keyvalue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,8 @@ func TestKeyValue_Keys(t *testing.T) {
connectors := testutil.TestConnectors(t)
defer testutil.CloseConnectors(t, connectors)

fiftyExpectedKeys := make([]string, 50)

tcs := []struct {
name string
setup func(db KeyValue)
Expand All @@ -729,21 +731,16 @@ func TestKeyValue_Keys(t *testing.T) {
{
name: "existing_keys_pattern",
setup: func(db KeyValue) {
err := db.Set(context.Background(), "key1", "value1", 0)
if err != nil {
t.Fatalf("Set() error = %v", err)
}
err = db.Set(context.Background(), "key2", "value2", 0)
if err != nil {
t.Fatalf("Set() error = %v", err)
}
err = db.Set(context.Background(), "test", "value2", 0)
if err != nil {
t.Fatalf("Set() error = %v", err)
for i := 1; i <= 50; i++ {
fiftyExpectedKeys[i-1] = fmt.Sprintf("key%d", i)
err := db.Set(context.Background(), fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 0)
if err != nil {
t.Fatalf("Set() error = %v", err)
}
}
},
expectedKeys: fiftyExpectedKeys,
pattern: "key*",
expectedKeys: []string{"key1", "key2"},
expectedErr: nil,
},
{
Expand Down

0 comments on commit 24ad76c

Please sign in to comment.