diff --git a/go.mod b/go.mod index 6f0e0580..74deb9d4 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,11 @@ require ( gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 ) +require ( + github.com/dchest/siphash v1.2.2 // indirect + github.com/dustinxie/lockfree v0.0.0-20210712051436-ed0ed42fd0d6 // indirect +) + require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index 628f340b..58bbbd88 100644 --- a/go.sum +++ b/go.sum @@ -5,8 +5,12 @@ github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dchest/siphash v1.2.2 h1:9DFz8tQwl9pTVt5iok/9zKyzA1Q6bRGiF3HPiEEVr9I= +github.com/dchest/siphash v1.2.2/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dustinxie/lockfree v0.0.0-20210712051436-ed0ed42fd0d6 h1:OCG9DHxQwv2sABVGARZaUh4OK8dVaR3kzTIHV0vW4gg= +github.com/dustinxie/lockfree v0.0.0-20210712051436-ed0ed42fd0d6/go.mod h1:m7oIj8lFrQgKxP9h9m6GxjzGbTuMD5/5yXF8+pTpJms= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -24,6 +28,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= diff --git a/temporal/connector/connector.go b/temporal/connector/connector.go index 27213540..7d8aea10 100644 --- a/temporal/connector/connector.go +++ b/temporal/connector/connector.go @@ -1,6 +1,7 @@ package connector import ( + "github.com/TykTechnologies/storage/temporal/internal/driver/local" "github.com/TykTechnologies/storage/temporal/internal/driver/redisv9" "github.com/TykTechnologies/storage/temporal/model" "github.com/TykTechnologies/storage/temporal/temperr" @@ -15,6 +16,8 @@ func NewConnector(connType string, options ...model.Option) (model.Connector, er switch connType { case model.RedisV9Type: return redisv9.NewRedisV9WithOpts(options...) + case model.LocalType: + return local.NewLocalConnector(local.NewLockFreeStore()), nil default: return nil, temperr.InvalidHandlerType } diff --git a/temporal/flusher/flusher.go b/temporal/flusher/flusher.go index 8e848eb2..d2667c6e 100644 --- a/temporal/flusher/flusher.go +++ b/temporal/flusher/flusher.go @@ -1,6 +1,7 @@ package flusher import ( + "github.com/TykTechnologies/storage/temporal/internal/driver/local" "github.com/TykTechnologies/storage/temporal/internal/driver/redisv9" "github.com/TykTechnologies/storage/temporal/model" "github.com/TykTechnologies/storage/temporal/temperr" @@ -12,6 +13,8 @@ func NewFlusher(conn model.Connector) (Flusher, error) { switch conn.Type() { case model.RedisV9Type: return redisv9.NewRedisV9WithConnection(conn) + case model.LocalType: + return local.NewLocalStore(conn), nil default: return nil, temperr.InvalidHandlerType } diff --git a/temporal/internal/driver/local/connector.go b/temporal/internal/driver/local/connector.go new file mode 100644 index 00000000..3bd11f10 --- /dev/null +++ b/temporal/internal/driver/local/connector.go @@ -0,0 +1,47 @@ +package local + +import ( + "context" + "sync" + + "github.com/TykTechnologies/storage/temporal/temperr" +) + +type LocalConnector struct { + Store KVStore + Broker Broker + mutex sync.RWMutex + connected bool +} + +// Disconnect disconnects from the backend +func (api *LocalConnector) Disconnect(context.Context) error { + api.mutex.RLock() + defer api.mutex.RUnlock() + api.connected = false + return nil +} + +// Ping executes a ping to the backend +func (api *LocalConnector) Ping(context.Context) error { + if !api.connected { + return temperr.ClosedConnection + } + + return nil +} + +// Type returns the connector type +func (api *LocalConnector) Type() string { + return "local" +} + +// As converts i to driver-specific types. +// Same concept as https://gocloud.dev/concepts/as/ but for connectors. +func (api *LocalConnector) As(i interface{}) bool { + if _, ok := i.(*API); ok { + return true + } + + return false +} diff --git a/temporal/internal/driver/local/flusher.go b/temporal/internal/driver/local/flusher.go new file mode 100644 index 00000000..b843fed2 --- /dev/null +++ b/temporal/internal/driver/local/flusher.go @@ -0,0 +1,34 @@ +package local + +import ( + "context" +) + +func (api *API) FlushAll(ctx context.Context) error { + // save the ops + _, ok := api.Store.Features()[FeatureFlushAll] + if ok { + err := api.Store.FlushAll() + if err != nil { + return err + } + + api.initialiseKeyIndexes() + } + + keyIndex, err := api.Store.Get(keyIndexKey) + if err != nil { + return err + } + + keys := keyIndex.Value.(map[string]interface{}) + for key := range keys { + err := api.Delete(ctx, key) + if err != nil { + return err + } + } + + api.initialiseKeyIndexes() + return nil +} diff --git a/temporal/internal/driver/local/keyvalue.go b/temporal/internal/driver/local/keyvalue.go new file mode 100644 index 00000000..b6ec5c6d --- /dev/null +++ b/temporal/internal/driver/local/keyvalue.go @@ -0,0 +1,413 @@ +package local + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/TykTechnologies/storage/temporal/temperr" +) + +func (api *API) Get(ctx context.Context, key string) (string, error) { + if key == "" { + return "", temperr.KeyEmpty + } + + o, err := api.Store.Get(key) + if err != nil { + return "", err + } + + if o == nil || o.IsExpired() || o.Deleted { + return "", temperr.KeyNotFound + } + + return api.convertToString(o.Value) +} + +func (api *API) convertToString(value interface{}) (string, error) { + switch v := value.(type) { + case string: + return v, nil + case int: + return strconv.Itoa(v), nil + case int64: + return strconv.FormatInt(v, 10), nil + case int32: + return strconv.FormatInt(int64(v), 10), nil + default: + return "", temperr.KeyMisstype + } +} + +func (api *API) Set(ctx context.Context, key, value string, ttl time.Duration) error { + if key == "" { + return temperr.KeyEmpty + } + + o := &Object{ + Value: value, + NoExp: ttl <= 0, + } + + if !o.NoExp { + o.SetExpire(ttl) + } + + if err := api.Store.Set(key, o); err != nil { + return err + } + + return api.addToKeyIndex(key) +} + +func (api *API) SetIfNotExist(ctx context.Context, key, value string, expiration time.Duration) (bool, error) { + if key == "" { + return false, temperr.KeyEmpty + } + + o, err := api.Store.Get(key) + if err != nil { + return false, err + } + + if o != nil && !o.Deleted && !o.IsExpired() { + return false, nil + } + + if err := api.Set(ctx, key, value, expiration); err != nil { + return false, err + } + + return true, nil +} + +func (api *API) Delete(ctx context.Context, key string) error { + if key == "" { + return temperr.KeyEmpty + } + + o, err := api.Store.Get(key) + if err != nil { + return err + } + + if o == nil { + return nil // Key doesn't exist, no need to delete + } + + // Check if hard delete is supported by the store + _, delSupport := api.Store.Features()[FeatureHardDelete] + if delSupport { + return api.Store.Delete(key) + } + + o.Deleted = true + o.DeletedAt = time.Now() + o.Value = "" // empty the value to save memory + + if err := api.Store.Set(key, o); err != nil { + return err + } + + return api.updateDeletedKeysIndex(key) +} + +func NewCounter(value int64) *Object { + return &Object{ + Value: value, + Type: TypeCounter, + NoExp: true, + } +} + +func (api *API) Increment(ctx context.Context, key string) (int64, error) { + if key == "" { + return 0, temperr.KeyEmpty + } + + o, err := api.Store.Get(key) + if err != nil || o == nil || o.Deleted || o.IsExpired() { + return api.createNewCounter(key) + } + + value, err := api.getCounterValue(o) + if err != nil { + return 0, err + } + + newValue := value + 1 + o.Value = newValue + o.Type = TypeCounter + + if err := api.Store.Set(key, o); err != nil { + return 0, err + } + + return newValue, nil +} + +func (api *API) createNewCounter(key string) (int64, error) { + o := NewCounter(1) + if err := api.Store.Set(key, o); err != nil { + return 0, err + } + if err := api.addToKeyIndex(key); err != nil { + return 0, err + } + return 1, nil +} + +func (api *API) getCounterValue(o *Object) (int64, error) { + switch v := o.Value.(type) { + case int: + return int64(v), nil + case int64: + return v, nil + case int32: + return int64(v), nil + case string: + i, err := strconv.Atoi(v) + if err != nil { + return 0, temperr.KeyMisstype + } + return int64(i), err + default: + return 0, temperr.KeyMisstype + } +} + +func (api *API) Decrement(ctx context.Context, key string) (int64, error) { + if key == "" { + return 0, temperr.KeyEmpty + } + + o, err := api.Store.Get(key) + if err != nil || o == nil || o.Deleted || o.IsExpired() { + return api.createNewCounterWithValue(key, -1) + } + + value, err := api.getCounterValue(o) + if err != nil { + return 0, err + } + + newValue := value - 1 + o.Value = newValue + o.Type = TypeCounter + + if err := api.Store.Set(key, o); err != nil { + return 0, err + } + + return newValue, nil +} + +func (api *API) createNewCounterWithValue(key string, value int64) (int64, error) { + o := NewCounter(value) + if err := api.Store.Set(key, o); err != nil { + return 0, err + } + if err := api.addToKeyIndex(key); err != nil { + return 0, err + } + return value, nil +} + +func (api *API) Exists(ctx context.Context, key string) (bool, error) { + if key == "" { + return false, temperr.KeyEmpty + } + + _, err := api.Get(ctx, key) + if err == nil { + return true, nil + } + if err == temperr.KeyNotFound { + return false, nil + } + return false, err +} + +func (api *API) Expire(ctx context.Context, key string, ttl time.Duration) error { + if key == "" { + return temperr.KeyEmpty + } + + // non-existing keys for these functions should return nil, not errors + o, err := api.Store.Get(key) + if err != nil { + return nil + } + if o == nil { + return nil + } + + if ttl <= 0 { + o.NoExp = true + } else { + o.SetExpire(ttl) + o.NoExp = false + } + + return api.Store.Set(key, o) +} + +func (api *API) TTL(ctx context.Context, key string) (int64, error) { + if key == "" { + return -2, temperr.KeyEmpty + } + + o, err := api.Store.Get(key) + if err != nil { + // bizarre, but should return nil + return -2, nil + } + if o == nil { + return -2, nil + } + + if o.NoExp { + return -1, nil + } + + ttl := time.Until(o.Exp).Round(time.Second).Seconds() + return int64(ttl), nil +} + +func (api *API) DeleteKeys(ctx context.Context, keys []string) (int64, error) { + if len(keys) == 0 { + return 0, temperr.KeyEmpty + } + + var deleted int64 + for _, key := range keys { + exists, err := api.Exists(ctx, key) + if err != nil { + return deleted, err + } + if exists { + if err := api.Delete(ctx, key); err != nil { + return deleted, err + } + deleted++ + } + } + + return deleted, nil +} + +func (api *API) DeleteScanMatch(ctx context.Context, pattern string) (int64, error) { + if err := api.Connector.Ping(ctx); err != nil { + return 0, err + } + + keys, err := api.Keys(ctx, pattern) + if err != nil { + return 0, err + } + + // need to return nil for this function + c, err := api.DeleteKeys(ctx, keys) + if err != nil { + return 0, nil + } + + return c, nil +} + +func (api *API) Keys(ctx context.Context, pattern string) ([]string, error) { + if err := api.Connector.Ping(ctx); err != nil { + return nil, err + } + + keyIndexObj, err := api.Store.Get(keyIndexKey) + if err != nil { + return nil, err + } + if keyIndexObj == nil { + return nil, nil + } + + keyIndex, ok := keyIndexObj.Value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid key index format") + } + + deletedKeyIndexObj, err := api.Store.Get(deletedKeyIndexKey) + if err != nil { + return nil, err + } + deletedKeys, ok := deletedKeyIndexObj.Value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid deleted key index format") + } + + pattern = strings.TrimSuffix(pattern, "*") + var matchedKeys []string + + for key := range keyIndex { + if !api.isKeyDeleted(key, deletedKeys) && strings.HasPrefix(key, pattern) { + matchedKeys = append(matchedKeys, key) + } + } + + return matchedKeys, nil +} + +func (api *API) isKeyDeleted(key string, deletedKeys map[string]interface{}) bool { + _, deleted := deletedKeys[key] + return deleted +} + +func (api *API) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { + var values []interface{} + + for _, key := range keys { + value, err := api.Get(ctx, key) + if err == temperr.KeyNotFound { + values = append(values, nil) + } else if err != nil { + return nil, err + } else { + values = append(values, value) + } + } + + return values, nil +} + +func (api *API) GetKeysAndValuesWithFilter(ctx context.Context, pattern string) (map[string]interface{}, error) { + keys, err := api.Keys(ctx, pattern) + if err != nil { + return nil, err + } + + keysAndValues := make(map[string]interface{}) + + for _, key := range keys { + value, err := api.Get(ctx, key) + if err == nil { + keysAndValues[key] = value + } else if err != temperr.KeyNotFound { + return nil, err + } + } + + return keysAndValues, nil +} + +func (api *API) GetKeysWithOpts(ctx context.Context, searchStr string, cursors map[string]uint64, + count int64) (keys []string, updatedCursor map[string]uint64, continueScan bool, err error) { + + if err := api.Connector.Ping(ctx); err != nil { + return nil, nil, false, err + } + + // TODO: Implement the actual functionality based on your requirements + // This function is currently a no-op and needs to be implemented + + return nil, nil, true, nil +} diff --git a/temporal/internal/driver/local/list.go b/temporal/internal/driver/local/list.go new file mode 100644 index 00000000..834c6eab --- /dev/null +++ b/temporal/internal/driver/local/list.go @@ -0,0 +1,236 @@ +package local + +import ( + "context" + "reflect" + + "github.com/TykTechnologies/storage/temporal/temperr" +) + +func NewListObject(value []string) *Object { + if value == nil { + value = []string{} + } + return &Object{ + Type: TypeList, + Value: value, + NoExp: true, + } +} + +// Remove the first count occurrences of elements equal to element from the list stored at key. If count is 0 remove all elements equal to element. +func (a *API) Remove(ctx context.Context, key string, count int64, iElement interface{}) (int64, error) { + obj, err := a.Store.Get(key) + if err != nil { + return 0, err + } + + if obj == nil || obj.Type != TypeList { + return 0, nil + } + + list := obj.Value.([]string) + var removed int64 + var newList []string + _, ok := iElement.([]byte) + if !ok { + return 0, temperr.KeyMisstype + } + + element := string(iElement.([]byte)) + + if count > 0 { + // Remove from head to tail + for _, item := range list { + if removed < count && reflect.DeepEqual(item, element) { + removed++ + } else { + newList = append(newList, item) + } + } + } else if count < 0 { + // Remove from tail to head + for i := len(list) - 1; i >= 0; i-- { + if removed < -count && reflect.DeepEqual(list[i], element) { + removed++ + } else { + newList = append([]string{list[i]}, newList...) + } + } + } else { // count == 0 + // Remove all occurrences + for _, item := range list { + if !reflect.DeepEqual(item, element) { + newList = append(newList, item) + } + } + removed = int64(len(list) - len(newList)) + } + + if removed > 0 { + obj.Value = newList + err = a.Store.Set(key, obj) + if err != nil { + return 0, err + } + } + + return removed, nil +} + +func (api *API) Range(ctx context.Context, key string, start, stop int64) ([]string, error) { + o, err := api.Store.Get(key) + if err != nil { + return nil, err + } + + if o == nil || o.Type != TypeList { + return nil, nil + } + + list := o.Value.([]string) + length := int64(len(list)) + + // Convert negative indices to positive + if start < 0 { + if start < 0 { + start = 0 + } + } + if stop < 0 { + stop = length + } + + // Ensure from is not greater than length + if start >= length { + return []string{}, nil + } + + // Ensure to is not greater than length + if stop >= length { + stop = length - 1 + } + + // Ensure from is not greater than to + if start > stop { + return []string{}, nil + } + + // +1 because slicing in Go is exclusive for the upper bound + return list[start : stop+1], nil +} + +// Returns the length of the list stored at key. +func (api *API) Length(ctx context.Context, key string) (int64, error) { + o, err := api.Store.Get(key) + if err != nil { + return 0, err + } + + if o == nil || o.Type != TypeList { + return 0, nil + } + + return int64(len(o.Value.([]string))), nil +} + +// Insert all the specified values at the head of the list stored at key. +// If key does not exist, it is created. +// pipelined: If true, the operation is pipelined and executed in a single roundtrip. +func (api *API) Prepend(ctx context.Context, pipelined bool, key string, values ...[]byte) error { + o, err := api.Store.Get(key) + if err != nil { + return err + } + + // reverse the vlaues + for i, j := 0, len(values)-1; i < j; i, j = i+1, j-1 { + values[i], values[j] = values[j], values[i] + } + + if o == nil { + l := make([]string, len(values)) + for i, value := range values { + l[i] = string(value) + } + + o = NewListObject(l) + api.Store.Set(key, o) + return nil + } + + if o.Type != TypeList { + return temperr.KeyMisstype + } + + // values is in order, but needs to be inserted in reverse order + for i := len(values) - 1; i >= 0; i-- { + o.Value = append([]string{string(values[i])}, o.Value.([]string)...) + } + + api.Store.Set(key, o) + return nil +} + +func (api *API) Append(ctx context.Context, pipelined bool, key string, values ...[]byte) error { + o, err := api.Store.Get(key) + if err != nil { + return err + } + + if o == nil { + l := make([]string, len(values)) + for i, value := range values { + l[i] = string(value) + } + o = NewListObject(l) + api.Store.Set(key, o) + return nil + } + + if o.Type != TypeList { + return temperr.KeyMisstype + } + + for _, value := range values { + o.Value = append(o.Value.([]string), string(value)) + } + + api.Store.Set(key, o) + return nil +} + +// Pop removes and returns the first count elements of the list stored at key. +// If stop is -1, all the elements from start to the end of the list are removed and returned. +func (api *API) Pop(ctx context.Context, key string, stop int64) ([]string, error) { + o, err := api.Store.Get(key) + if err != nil { + return nil, err + } + + if o == nil || o.Type != TypeList { + return nil, nil + } + + list := o.Value.([]string) + length := int64(len(list)) + + var incl int64 = 0 + if stop == -1 { + stop = length + incl = 1 + } + + if stop >= length { + stop = length - 1 + } + + if stop < 0 { + return []string{}, nil + } + + popped := list[:stop+incl] + o.Value = list[stop+incl:] + api.Store.Set(key, o) + return popped, nil +} diff --git a/temporal/internal/driver/local/local.go b/temporal/internal/driver/local/local.go new file mode 100644 index 00000000..a1600971 --- /dev/null +++ b/temporal/internal/driver/local/local.go @@ -0,0 +1,72 @@ +package local + +import ( + "github.com/TykTechnologies/storage/temporal/model" +) + +type API struct { + // Store has no inherent features except storing and retrieving data + Store KVStore + Connector model.Connector + Broker Broker +} + +const ( + keyIndexKey = "rumbaba:keyIndex" + deletedKeyIndexKey = "rumbaba:deletedKeyIndex" + + TypeBytes = "bytes" + TypeSet = "set" + TypeSortedSet = "sortedset" + TypeList = "list" + TypeCounter = "counter" +) + +var mockStore *MockStore + +func init() { + mockStore = NewMockStore() +} + +func GetMockStore() *MockStore { + return mockStore +} + +func NewLocalConnector(withStore KVStore) model.Connector { + return &LocalConnector{ + Store: withStore, + connected: true, + Broker: NewMockBroker(), + } +} + +func NewLocalStore(connector model.Connector) *API { + api := &API{ + Connector: connector, + Store: connector.(*LocalConnector).Store, + Broker: connector.(*LocalConnector).Broker, + } + + initAlready, _ := api.Store.Get(keyIndexKey) + if initAlready != nil { + return api + } + + // init the key indexes + api.initialiseKeyIndexes() + return api +} + +func (api *API) initialiseKeyIndexes() { + api.Store.Set(keyIndexKey, &Object{ + Type: TypeList, + Value: map[string]interface{}{}, + NoExp: true, + }) + + api.Store.Set(deletedKeyIndexKey, &Object{ + Type: TypeList, + Value: map[string]interface{}{}, + NoExp: true, + }) +} diff --git a/temporal/internal/driver/local/mockbroker.go b/temporal/internal/driver/local/mockbroker.go new file mode 100644 index 00000000..897093bd --- /dev/null +++ b/temporal/internal/driver/local/mockbroker.go @@ -0,0 +1,148 @@ +package local + +import ( + "context" + "errors" + "sync" + + "github.com/TykTechnologies/storage/temporal/model" +) + +// MockBroker is a mock implementation of the Broker interface +type MockBroker struct { + subscriptions map[string][]chan model.Message + mu sync.RWMutex +} + +// NewMockBroker creates a new MockBroker +func NewMockBroker() *MockBroker { + return &MockBroker{ + subscriptions: make(map[string][]chan model.Message), + } +} + +// Publish sends a message to all subscribers of the specified channel +func (mb *MockBroker) Publish(channel, message string) (int64, error) { + mb.mu.RLock() + defer mb.mu.RUnlock() + + subscribers, ok := mb.subscriptions[channel] + if !ok { + return 0, nil + } + + msg := &MockMessage{ + messageType: "message", + channel: channel, + payload: message, + } + + for _, ch := range subscribers { + select { + case ch <- msg: + default: + // If the channel is full, we skip this subscriber + } + } + + return int64(len(subscribers)), nil +} + +// Subscribe creates a new subscription for the specified channels +func (mb *MockBroker) Subscribe(channels ...string) model.Subscription { + mb.mu.Lock() + defer mb.mu.Unlock() + + msgChan := make(chan model.Message, 100) + sub := &MockSubscription{ + broker: mb, + channels: channels, + msgChan: msgChan, + } + + for _, channel := range channels { + mb.subscriptions[channel] = append(mb.subscriptions[channel], msgChan) + // Send subscription confirmation message + msgChan <- &MockMessage{ + messageType: "subscription", + channel: channel, + payload: "subscribe", + } + } + + return sub +} + +// MockSubscription is a mock implementation of the Subscription interface +type MockSubscription struct { + broker *MockBroker + channels []string + msgChan chan model.Message + closed bool + mu sync.Mutex +} + +// Receive waits for and returns the next message from the subscription +func (ms *MockSubscription) Receive(ctx context.Context) (model.Message, error) { + select { + case msg := <-ms.msgChan: + return msg, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Close closes the subscription and cleans up resources +func (ms *MockSubscription) Close() error { + ms.mu.Lock() + defer ms.mu.Unlock() + + if ms.closed { + return errors.New("subscription already closed") + } + + ms.closed = true + + ms.broker.mu.Lock() + defer ms.broker.mu.Unlock() + + for _, channel := range ms.channels { + subscribers := ms.broker.subscriptions[channel] + for i, ch := range subscribers { + if ch == ms.msgChan { + ms.broker.subscriptions[channel] = append(subscribers[:i], subscribers[i+1:]...) + break + } + } + } + + return nil +} + +// MockMessage is a mock implementation of the Message interface +type MockMessage struct { + messageType string + channel string + payload string +} + +// Type returns the message type +func (mm *MockMessage) Type() string { + return mm.messageType +} + +// Channel returns the channel the message was received on +func (mm *MockMessage) Channel() (string, error) { + if mm.messageType == "message" || mm.messageType == "subscription" { + return mm.channel, nil + } + return "", errors.New("invalid message type") +} + +// Payload returns the message payload +func (mm *MockMessage) Payload() (string, error) { + if mm.messageType == "message" || mm.messageType == "subscription" { + return mm.payload, nil + } + return "", errors.New("invalid message type") +} diff --git a/temporal/internal/driver/local/mockbroker_test.go b/temporal/internal/driver/local/mockbroker_test.go new file mode 100644 index 00000000..7bc9816d --- /dev/null +++ b/temporal/internal/driver/local/mockbroker_test.go @@ -0,0 +1,186 @@ +package local + +import ( + "context" + "testing" + "time" +) + +func TestNewMockBroker(t *testing.T) { + broker := NewMockBroker() + if broker == nil { + t.Fatal("NewMockBroker() returned nil") + } + if broker.subscriptions == nil { + t.Error("NewMockBroker() did not initialize subscriptions map") + } +} + +func TestMockBroker_Publish(t *testing.T) { + broker := NewMockBroker() + channel := "testChannel" + message := "testMessage" + + // Test publishing to a channel with no subscribers + count, err := broker.Publish(channel, message) + if err != nil { + t.Errorf("Publish() error = %v", err) + } + if count != 0 { + t.Errorf("Publish() to empty channel returned count = %d, want 0", count) + } + + // Add a subscriber + sub := broker.Subscribe(channel) + defer sub.Close() + + // Test publishing to a channel with a subscriber + count, err = broker.Publish(channel, message) + if err != nil { + t.Errorf("Publish() error = %v", err) + } + if count != 1 { + t.Errorf("Publish() returned count = %d, want 1", count) + } + + // Verify the message was received + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + receivedMsg1, err := sub.Receive(ctx) + if err != nil { + t.Errorf("Receive() error = %v", err) + } + if receivedMsg1.Type() != "subscription" { + t.Errorf("Received message type = %s, want 'subscription'", receivedMsg1.Type()) + } + receivedChannel, _ := receivedMsg1.Channel() + if receivedChannel != channel { + t.Errorf("Received message channel = %s, want %s", receivedChannel, channel) + } + receivedPayload, _ := receivedMsg1.Payload() + if receivedPayload != "subscribe" { + t.Errorf("Received message payload = %s, want %s", receivedPayload, "subscribe") + } + + receivedMsg2, err := sub.Receive(ctx) + if err != nil { + t.Errorf("Receive() error = %v", err) + } + if receivedMsg2.Type() != "message" { + t.Errorf("Received message type = %s, want 'message'", receivedMsg2.Type()) + } + receivedChannel2, _ := receivedMsg2.Channel() + if receivedChannel2 != channel { + t.Errorf("Received message channel = %s, want %s", receivedChannel2, channel) + } + receivedPayload2, _ := receivedMsg2.Payload() + if receivedPayload2 != message { + t.Errorf("Received message payload = %s, want %s", receivedPayload2, message) + } +} + +func TestMockBroker_Subscribe(t *testing.T) { + broker := NewMockBroker() + channels := []string{"channel1", "channel2"} + + sub := broker.Subscribe(channels...) + + // Verify subscription confirmation messages + for _, channel := range channels { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + msg, err := sub.Receive(ctx) + cancel() + if err != nil { + t.Errorf("Receive() error = %v", err) + } + if msg.Type() != "subscription" { + t.Errorf("Subscription message type = %s, want 'subscription'", msg.Type()) + } + msgChannel, _ := msg.Channel() + if msgChannel != channel { + t.Errorf("Subscription message channel = %s, want %s", msgChannel, channel) + } + payload, _ := msg.Payload() + if payload != "subscribe" { + t.Errorf("Subscription message payload = %s, want 'subscribe'", payload) + } + } + + // Verify subscriptions were added + broker.mu.RLock() + defer broker.mu.RUnlock() + for _, channel := range channels { + if len(broker.subscriptions[channel]) != 1 { + t.Errorf("Channel %s has %d subscribers, want 1", channel, len(broker.subscriptions[channel])) + } + } +} + +func TestMockSubscription_Close(t *testing.T) { + broker := NewMockBroker() + channel := "testChannel" + sub := broker.Subscribe(channel) + + // Close the subscription + err := sub.Close() + if err != nil { + t.Errorf("Close() error = %v", err) + } + + // Verify the subscription was removed + broker.mu.RLock() + defer broker.mu.RUnlock() + if len(broker.subscriptions[channel]) != 0 { + t.Errorf("Channel %s has %d subscribers after close, want 0", channel, len(broker.subscriptions[channel])) + } + + // Try to close again + err = sub.Close() + if err == nil { + t.Error("Close() on already closed subscription should return an error") + } +} + +func TestMockMessage(t *testing.T) { + tests := []struct { + name string + messageType string + channel string + payload string + wantErr bool + }{ + {"Valid message", "message", "testChannel", "testPayload", false}, + {"Valid subscription", "subscription", "testChannel", "subscribe", false}, + {"Invalid type", "invalid", "testChannel", "testPayload", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := &MockMessage{ + messageType: tt.messageType, + channel: tt.channel, + payload: tt.payload, + } + + if msg.Type() != tt.messageType { + t.Errorf("Type() = %v, want %v", msg.Type(), tt.messageType) + } + + channel, err := msg.Channel() + if (err != nil) != tt.wantErr { + t.Errorf("Channel() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && channel != tt.channel { + t.Errorf("Channel() = %v, want %v", channel, tt.channel) + } + + payload, err := msg.Payload() + if (err != nil) != tt.wantErr { + t.Errorf("Payload() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && payload != tt.payload { + t.Errorf("Payload() = %v, want %v", payload, tt.payload) + } + }) + } +} diff --git a/temporal/internal/driver/local/mockstore.go b/temporal/internal/driver/local/mockstore.go new file mode 100644 index 00000000..8fd60c68 --- /dev/null +++ b/temporal/internal/driver/local/mockstore.go @@ -0,0 +1,57 @@ +package local + +import ( + "sync" +) + +// MockStore implements the KVStore interface using a map +type MockStore struct { + data map[string]interface{} + mutex sync.RWMutex +} + +func NewMockStore() *MockStore { + return &MockStore{ + data: make(map[string]interface{}), + } +} + +func (m *MockStore) Get(key string) (*Object, error) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + if obj, ok := m.data[key]; ok { + return obj.(*Object), nil + } + return nil, nil +} + +func (m *MockStore) Set(key string, value interface{}) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.data[key] = value + return nil +} + +func (m *MockStore) Delete(key string) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + delete(m.data, key) + return nil +} + +func (m *MockStore) FlushAll() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.data = make(map[string]interface{}) + return nil +} + +func (m *MockStore) Features() map[ExtendedFeature]bool { + return map[ExtendedFeature]bool{ + FeatureFlushAll: true, + } +} diff --git a/temporal/internal/driver/local/mockstore_test.go b/temporal/internal/driver/local/mockstore_test.go new file mode 100644 index 00000000..26152140 --- /dev/null +++ b/temporal/internal/driver/local/mockstore_test.go @@ -0,0 +1,130 @@ +package local + +import ( + "fmt" + "reflect" + "sync" + "testing" +) + +func TestNewMockStore(t *testing.T) { + store := NewMockStore() + if store == nil { + t.Fatal("NewMockStore() returned nil") + } + if store.data == nil { + t.Error("NewMockStore() did not initialize data map") + } +} + +func TestMockStore_Get(t *testing.T) { + store := NewMockStore() + obj := &Object{} // Assume Object is defined elsewhere + store.data["testKey"] = obj + + tests := []struct { + name string + key string + want *Object + wantErr bool + }{ + {"Existing key", "testKey", obj, false}, + {"Non-existing key", "nonExistingKey", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := store.Get(tt.key) + if (err != nil) != tt.wantErr { + t.Errorf("MockStore.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MockStore.Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMockStore_Set(t *testing.T) { + store := NewMockStore() + obj := &Object{} + + if err := store.Set("testKey", obj); err != nil { + t.Errorf("MockStore.Set() error = %v", err) + } + + if store.data["testKey"] != obj { + t.Errorf("MockStore.Set() did not store the object correctly") + } +} + +func TestMockStore_Delete(t *testing.T) { + store := NewMockStore() + store.data["testKey"] = &Object{} + + if err := store.Delete("testKey"); err != nil { + t.Errorf("MockStore.Delete() error = %v", err) + } + + if _, exists := store.data["testKey"]; exists { + t.Errorf("MockStore.Delete() did not remove the key") + } +} + +func TestMockStore_FlushAll(t *testing.T) { + store := NewMockStore() + store.data["testKey"] = &Object{} + + if err := store.FlushAll(); err != nil { + t.Errorf("MockStore.FlushAll() error = %v", err) + } + + if len(store.data) != 0 { + t.Errorf("MockStore.FlushAll() did not remove all keys") + } +} + +func TestMockStore_Features(t *testing.T) { + store := NewMockStore() + features := store.Features() + + if !features[FeatureFlushAll] { + t.Errorf("MockStore.Features() FeatureFlushAll should be true") + } +} + +func TestMockStore_Concurrency(t *testing.T) { + store := NewMockStore() + const goroutines = 100 + var wg sync.WaitGroup + wg.Add(goroutines * 4) // 4 operations per goroutine + + for i := 0; i < goroutines; i++ { + go func(i int) { + defer wg.Done() + key := fmt.Sprintf("key%d", i) + store.Set(key, &Object{}) + }(i) + + go func(i int) { + defer wg.Done() + key := fmt.Sprintf("key%d", i) + store.Get(key) + }(i) + + go func(i int) { + defer wg.Done() + key := fmt.Sprintf("key%d", i) + store.Delete(key) + }(i) + + go func() { + defer wg.Done() + store.FlushAll() + }() + } + + wg.Wait() + // If we reach here without deadlocks or race conditions, the test passes +} diff --git a/temporal/internal/driver/local/nonblocking_store.go b/temporal/internal/driver/local/nonblocking_store.go new file mode 100644 index 00000000..0f45fddb --- /dev/null +++ b/temporal/internal/driver/local/nonblocking_store.go @@ -0,0 +1,69 @@ +package local + +import "github.com/dustinxie/lockfree" + +type LockFreeStore struct { + store lockfree.HashMap + Broker Broker +} + +func NewLockFreeStore() *LockFreeStore { + return &LockFreeStore{ + store: lockfree.NewHashMap(), + Broker: NewMockBroker(), + } +} + +func (m *LockFreeStore) Get(key string) (*Object, error) { + v, ok := m.store.Get(key) + if !ok { + return nil, nil + } + + if v == nil { + return nil, nil + } + + return v.(*Object), nil +} + +func (m *LockFreeStore) Set(key string, value interface{}) error { + m.store.Set(key, value) + return nil +} + +func (m *LockFreeStore) Delete(key string) error { + m.store.Del(key) + return nil +} + +func (m *LockFreeStore) FlushAll() error { + delList := []interface{}{} + f := func(k interface{}, v interface{}) error { + delList = append(delList, k) + return nil + } + + m.store.Lock() + for k, v, ok := m.store.Next(); ok; k, v, ok = m.store.Next() { + if err := f(k, v); err != nil { + // unlock the map before return, otherwise it will deadlock + m.store.Unlock() + return err + } + } + m.store.Unlock() + + for _, k := range delList { + m.store.Del(k) + } + + return nil +} + +func (m *LockFreeStore) Features() map[ExtendedFeature]bool { + return map[ExtendedFeature]bool{ + FeatureFlushAll: true, + FeatureHardDelete: true, + } +} diff --git a/temporal/internal/driver/local/nonblocking_store_test.go b/temporal/internal/driver/local/nonblocking_store_test.go new file mode 100644 index 00000000..613faaf4 --- /dev/null +++ b/temporal/internal/driver/local/nonblocking_store_test.go @@ -0,0 +1,165 @@ +package local + +import ( + "fmt" + "reflect" + "sync" + "testing" +) + +func TestNewLockFreeStore(t *testing.T) { + store := NewLockFreeStore() + if store == nil { + t.Fatal("NewLockFreeStore() returned nil") + } + if store.store == nil { + t.Error("NewLockFreeStore() did not initialize store") + } + if store.Broker == nil { + t.Error("NewLockFreeStore() did not initialize Broker") + } +} + +func TestLockFreeStore_Get(t *testing.T) { + store := NewLockFreeStore() + obj := &Object{} // Assume Object is defined elsewhere + store.store.Set("testKey", obj) + + tests := []struct { + name string + key string + want *Object + wantErr bool + }{ + {"Existing key", "testKey", obj, false}, + {"Non-existing key", "nonExistingKey", nil, false}, + {"Nil value", "nilKey", nil, false}, + } + + store.store.Set("nilKey", nil) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := store.Get(tt.key) + if (err != nil) != tt.wantErr { + t.Errorf("LockFreeStore.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("LockFreeStore.Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLockFreeStore_Set(t *testing.T) { + store := NewLockFreeStore() + obj := &Object{} + + if err := store.Set("testKey", obj); err != nil { + t.Errorf("LockFreeStore.Set() error = %v", err) + } + + if v, ok := store.store.Get("testKey"); !ok || v != obj { + t.Errorf("LockFreeStore.Set() did not store the object correctly") + } +} + +func TestLockFreeStore_Delete(t *testing.T) { + store := NewLockFreeStore() + store.store.Set("testKey", &Object{}) + + if err := store.Delete("testKey"); err != nil { + t.Errorf("LockFreeStore.Delete() error = %v", err) + } + + if _, ok := store.store.Get("testKey"); ok { + t.Errorf("LockFreeStore.Delete() did not remove the key") + } +} + +func TestLockFreeStore_FlushAll(t *testing.T) { + store := NewLockFreeStore() + + // Add some test data + testKeys := []string{"testKey1", "testKey2", "testKey3"} + for _, key := range testKeys { + store.Set(key, &Object{}) + } + + // Verify data is present + for _, key := range testKeys { + if _, ok := store.store.Get(key); !ok { + t.Errorf("Test setup failed: key %s not found before FlushAll", key) + } + } + + // Perform FlushAll + if err := store.FlushAll(); err != nil { + t.Errorf("LockFreeStore.FlushAll() error = %v", err) + } + + // Verify all data has been removed + for _, key := range testKeys { + if value, ok := store.store.Get(key); ok { + t.Errorf("LockFreeStore.FlushAll() did not remove key %s, value: %v", key, value) + } + } + + // Additional check: try to add and retrieve a new key-value pair + newKey := "newTestKey" + newValue := &Object{} + store.Set(newKey, newValue) + + retrievedValue, err := store.Get(newKey) + if err != nil { + t.Errorf("Error retrieving new key after FlushAll: %v", err) + } + if retrievedValue != newValue { + t.Errorf("Retrieved value does not match set value after FlushAll") + } +} + +func TestLockFreeStore_Features(t *testing.T) { + store := NewLockFreeStore() + features := store.Features() + + if !features[FeatureFlushAll] { + t.Errorf("LockFreeStore.Features() FeatureFlushAll should be true") + } +} + +func TestLockFreeStore_Concurrency(t *testing.T) { + store := NewLockFreeStore() + const goroutines = 1000 + var wg sync.WaitGroup + wg.Add(goroutines * 4) // 4 operations per goroutine + + for i := 0; i < goroutines; i++ { + go func(i int) { + defer wg.Done() + key := fmt.Sprintf("key%d", i) + store.Set(key, &Object{}) + }(i) + + go func(i int) { + defer wg.Done() + key := fmt.Sprintf("key%d", i) + store.Get(key) + }(i) + + go func(i int) { + defer wg.Done() + key := fmt.Sprintf("key%d", i) + store.Delete(key) + }(i) + + go func() { + defer wg.Done() + store.FlushAll() + }() + } + + wg.Wait() + // If we reach here without deadlocks or race conditions, the test passes +} diff --git a/temporal/internal/driver/local/queue.go b/temporal/internal/driver/local/queue.go new file mode 100644 index 00000000..087f89cc --- /dev/null +++ b/temporal/internal/driver/local/queue.go @@ -0,0 +1,34 @@ +package local + +import ( + "context" + + "github.com/TykTechnologies/storage/temporal/model" +) + +// ===== QUEUE ===== + +// Publish sends a message to the specified channel. +// It returns the number of clients that received the message. +func (api *API) Publish(ctx context.Context, channel, message string) (int64, error) { + // We're ignoring the context here as the Broker interface doesn't use it. + // In a real implementation, you might want to respect context cancellation. + err := api.Connector.Ping(ctx) + if err != nil { + return 0, err + } + + return api.Broker.Publish(channel, message) +} + +// Subscribe initializes a subscription to one or more channels. +// It returns a Subscription interface that allows receiving messages and closing the subscription. +func (api *API) Subscribe(ctx context.Context, channels ...string) model.Subscription { + // We're ignoring the context here as the Broker interface doesn't use it. + // In a real implementation, you might want to respect context cancellation. + err := api.Connector.Ping(ctx) + if err != nil { + return nil + } + return api.Broker.Subscribe(channels...) +} diff --git a/temporal/internal/driver/local/set.go b/temporal/internal/driver/local/set.go new file mode 100644 index 00000000..2a2979ca --- /dev/null +++ b/temporal/internal/driver/local/set.go @@ -0,0 +1,150 @@ +package local + +import ( + "context" + "errors" + + "github.com/TykTechnologies/storage/temporal/temperr" +) + +func NewSetObject(value []interface{}) *Object { + return &Object{ + Type: TypeSet, + Value: value, + NoExp: true, + } +} + +func (api *API) Members(ctx context.Context, key string) ([]string, error) { + if key == "" { + return nil, temperr.KeyEmpty + } + + o, err := api.Store.Get(key) + if err != nil { + return nil, err + } + + if o == nil { + return []string{}, nil + } + + if o.Deleted || o.IsExpired() { + return []string{}, nil + } + + if o.Type != TypeSet { + return nil, errors.New("key not a valid set") + } + + set := make([]string, 0) + if o.Value == nil { + return set, nil + } + + for _, v := range o.Value.([]interface{}) { + vs, ok := v.(string) + if !ok { + return nil, errors.New("invalid set member") + } + + set = append(set, vs) + } + + return set, nil +} + +func (api *API) AddMember(ctx context.Context, key, member string) error { + if key == "" { + return temperr.KeyEmpty + } + + o, err := api.Store.Get(key) + if err != nil { + o = NewSetObject([]interface{}{member}) + api.Store.Set(key, o) + } + + if o == nil { + o = NewSetObject([]interface{}{member}) + api.Store.Set(key, o) + return nil + } + + if o.Type != TypeSet || o.Deleted || o.IsExpired() { + return temperr.KeyNotFound + } + + o.Value = append(o.Value.([]interface{}), member) + + err = api.Store.Set(key, o) + if err != nil { + return err + } + + return nil +} + +func (api *API) RemoveMember(ctx context.Context, key, member string) error { + if key == "" { + return temperr.KeyEmpty + } + + o, err := api.Store.Get(key) + if err != nil { + return err + } + + if o == nil { + return nil + } + + if o.Type != TypeSet || o.Deleted || o.IsExpired() { + return errors.New("key not a valid set") + } + + var newSet []interface{} + for _, v := range o.Value.([]interface{}) { + if v == member { + continue + } + + newSet = append(newSet, v) + } + + o.Value = newSet + + err = api.Store.Set(key, o) + if err != nil { + return err + } + + return nil +} + +func (api *API) IsMember(ctx context.Context, key, member string) (bool, error) { + if key == "" { + return false, temperr.KeyEmpty + } + + o, err := api.Store.Get(key) + if err != nil { + return false, err + } + + if o == nil { + return false, nil + } + + if o.Type != TypeSet || o.Deleted || o.IsExpired() { + return false, errors.New("key not a valid set") + } + + for _, v := range o.Value.([]interface{}) { + if v == member { + return true, nil + } + } + + return false, nil +} diff --git a/temporal/internal/driver/local/sortedset.go b/temporal/internal/driver/local/sortedset.go new file mode 100644 index 00000000..89e06ce7 --- /dev/null +++ b/temporal/internal/driver/local/sortedset.go @@ -0,0 +1,195 @@ +package local + +import ( + "context" + "fmt" + "math" + "sort" + "strconv" + "strings" + + "github.com/TykTechnologies/storage/temporal/temperr" +) + +type SortedSetEntry struct { + Score float64 + Member string +} + +type SortedSet []SortedSetEntry + +func (s SortedSet) Len() int { return len(s) } +func (s SortedSet) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s SortedSet) Less(i, j int) bool { + if s[i].Score == s[j].Score { + return s[i].Member < s[j].Member + } + return s[i].Score < s[j].Score +} + +func NewSortedSetObject() *Object { + return &Object{ + Type: TypeSortedSet, + Value: make(SortedSet, 0), + NoExp: true, + } +} + +func (api *API) AddScoredMember(ctx context.Context, key, member string, score float64) (int64, error) { + var added int64 + o, err := api.Store.Get(key) + if err != nil { + o = NewSortedSetObject() + } + + if o == nil { + o = NewSortedSetObject() + } + + if o.Type != TypeSortedSet { + return 0, fmt.Errorf("key is not a sorted set") + } + + sortedSet := o.Value.(SortedSet) + + index := -1 + for j, entry := range sortedSet { + if entry.Member == member { + index = j + break + } + } + + if index == -1 { + sortedSet = append(sortedSet, SortedSetEntry{Score: score, Member: member}) + added++ + } else { + sortedSet[index].Score = score + } + + sort.Sort(sortedSet) + o.Value = sortedSet + api.Store.Set(key, o) + return added, nil +} + +func (api *API) GetMembersByScoreRange(ctx context.Context, key, minScore, maxScore string) ([]interface{}, []float64, error) { + o, err := api.Store.Get(key) + if err != nil { + return nil, nil, err + } + + if o == nil { + return []interface{}{}, []float64{}, nil + } + + if o.Deleted || o.IsExpired() { + return []interface{}{}, []float64{}, nil + } + + if o.Type != TypeSortedSet { + return nil, nil, temperr.KeyMisstype + } + + sortedSet := o.Value.(SortedSet) + + from, fromInclusive, err := parseScore(minScore) + if err != nil { + return []interface{}{}, []float64{}, err + } + + to, toInclusive, err := parseScore(maxScore) + if err != nil { + return []interface{}{}, []float64{}, err + } + + var members = make([]interface{}, 0) + var scores = make([]float64, 0) + + for _, entry := range sortedSet { + if (fromInclusive && entry.Score >= from || !fromInclusive && entry.Score > from) && + (toInclusive && entry.Score <= to || !toInclusive && entry.Score < to) { + members = append(members, entry.Member) + scores = append(scores, entry.Score) + } + if entry.Score > to { + break + } + } + + return members, scores, nil +} + +func (api *API) RemoveMembersByScoreRange(ctx context.Context, key, minScore, maxScore string) (int64, error) { + o, err := api.Store.Get(key) + var removed int64 + if err != nil { + return 0, err + } + + if o == nil { + return 0, nil + } + + if o.Type != TypeSortedSet { + return 0, temperr.KeyMisstype + } + + if o.Deleted || o.IsExpired() { + return 0, nil + } + + sortedSet := o.Value.(SortedSet) + + from, fromInclusive, err := parseScore(minScore) + if err != nil { + return 0, err + } + + to, toInclusive, err := parseScore(maxScore) + if err != nil { + return 0, err + } + + var newSet SortedSet + + for _, entry := range sortedSet { + if (fromInclusive && entry.Score >= from || !fromInclusive && entry.Score > from) && + (toInclusive && entry.Score <= to || !toInclusive && entry.Score < to) { + // Skip this entry (effectively removing it) + removed++ + continue + } + newSet = append(newSet, entry) + } + + o.Value = newSet + err = api.Store.Set(key, o) + if err != nil { + return 0, err + } + + return removed, nil +} + +func parseScore(score string) (float64, bool, error) { + inclusive := true + if strings.HasPrefix(score, "(") { + inclusive = false + score = score[1:] + } + + if score == "-inf" { + return math.Inf(-1), inclusive, nil + } + if score == "+inf" { + return math.Inf(1), inclusive, nil + } + + value, err := strconv.ParseFloat(score, 64) + if err != nil { + return 0, false, fmt.Errorf("invalid score: %s", score) + } + + return value, inclusive, nil +} diff --git a/temporal/internal/driver/local/types.go b/temporal/internal/driver/local/types.go new file mode 100644 index 00000000..ccc822b9 --- /dev/null +++ b/temporal/internal/driver/local/types.go @@ -0,0 +1,54 @@ +package local + +import ( + "time" + + "github.com/TykTechnologies/storage/temporal/model" +) + +type Object struct { + Exp time.Time + Type string + NoExp bool + Deleted bool + DeletedAt time.Time + Value interface{} +} + +func (o *Object) IsExpired() bool { + if o.NoExp { + return false + } + + return time.Now().After(o.Exp) +} + +func (o *Object) SetExpire(d time.Duration) { + o.Exp = time.Now().Add(d) + o.NoExp = false +} + +func (o *Object) SetExpireAt(t time.Time) { + o.Exp = t + o.NoExp = false +} + +type ExtendedFeature string + +const ( + FeatureFlushAll ExtendedFeature = "flushall" + FeatureHardDelete ExtendedFeature = "harddelete" +) + +type KVStore interface { + Get(key string) (*Object, error) + Set(key string, value interface{}) error + Delete(key string) error + FlushAll() error + Features() map[ExtendedFeature]bool +} + +type Broker interface { + Publish(channel, message string) (int64, error) + Subscribe(channels ...string) model.Subscription +} diff --git a/temporal/internal/driver/local/util.go b/temporal/internal/driver/local/util.go new file mode 100644 index 00000000..2c9a3fb5 --- /dev/null +++ b/temporal/internal/driver/local/util.go @@ -0,0 +1,63 @@ +package local + +func (api *API) addToKeyIndex(key string) error { + o, err := api.Store.Get(keyIndexKey) + if err != nil { + // not found, create new + o = &Object{ + Type: TypeSet, + Value: map[string]interface{}{key: true}, + NoExp: true, + } + + return api.Store.Set(keyIndexKey, o) + } + + if o == nil { + o = &Object{ + Type: TypeSet, + Value: map[string]interface{}{key: true}, + NoExp: true, + } + } + + list := o.Value.(map[string]interface{}) + list[key] = true + + o.Value = list + + err = api.Store.Set(keyIndexKey, o) + if err != nil { + return err + } + + return nil +} + +func (api *API) updateDeletedKeysIndex(key string) error { + o, err := api.Store.Get(deletedKeyIndexKey) + if err != nil { + // not found, create new + o = &Object{ + Type: TypeSet, + Value: map[string]interface{}{key: true}, + NoExp: true, + } + + return api.Store.Set(deletedKeyIndexKey, o) + } + + if o == nil { + o = &Object{ + Type: TypeSet, + Value: map[string]interface{}{key: true}, + NoExp: true, + } + } + + list := o.Value.(map[string]interface{}) + list[key] = true + o.Value = list + + return api.Store.Set(deletedKeyIndexKey, o) +} diff --git a/temporal/internal/driver/local/util_test.go b/temporal/internal/driver/local/util_test.go new file mode 100644 index 00000000..59782c69 --- /dev/null +++ b/temporal/internal/driver/local/util_test.go @@ -0,0 +1,119 @@ +package local + +import ( + "testing" +) + +func TestAPI_addToKeyIndex(t *testing.T) { + tests := []struct { + name string + initialData map[string]*Object + key string + wantErr bool + }{ + { + name: "New key index", + initialData: map[string]*Object{}, + key: "testKey", + wantErr: false, + }, + { + name: "Existing key index", + initialData: map[string]*Object{ + keyIndexKey: { + Type: TypeSet, + Value: map[string]interface{}{"existingKey": true}, + NoExp: true, + }, + }, + key: "newKey", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockStore := NewMockStore() + for k, v := range tt.initialData { + mockStore.Set(k, v) + } + + api := &API{Store: mockStore} + err := api.addToKeyIndex(tt.key) + + if (err != nil) != tt.wantErr { + t.Errorf("API.addToKeyIndex() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Verify the key was added + obj, _ := mockStore.Get(keyIndexKey) + if obj == nil { + t.Errorf("Key index object not found") + return + } + + keySet := obj.Value.(map[string]interface{}) + if _, exists := keySet[tt.key]; !exists { + t.Errorf("Key %s not found in key index", tt.key) + } + }) + } +} + +func TestAPI_updateDeletedKeysIndex(t *testing.T) { + tests := []struct { + name string + initialData map[string]*Object + key string + wantErr bool + }{ + { + name: "New deleted key index", + initialData: map[string]*Object{}, + key: "deletedKey", + wantErr: false, + }, + { + name: "Existing deleted key index", + initialData: map[string]*Object{ + deletedKeyIndexKey: { + Type: TypeSet, + Value: map[string]interface{}{"existingDeletedKey": true}, + NoExp: true, + }, + }, + key: "newDeletedKey", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockStore := NewMockStore() + for k, v := range tt.initialData { + mockStore.Set(k, v) + } + + api := &API{Store: mockStore} + err := api.updateDeletedKeysIndex(tt.key) + + if (err != nil) != tt.wantErr { + t.Errorf("API.updateDeletedKeysIndex() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Verify the key was added to the deleted keys index + obj, _ := mockStore.Get(deletedKeyIndexKey) + if obj == nil { + t.Errorf("Deleted key index object not found") + return + } + + deletedKeySet := obj.Value.(map[string]interface{}) + if _, exists := deletedKeySet[tt.key]; !exists { + t.Errorf("Key %s not found in deleted key index", tt.key) + } + }) + } +} diff --git a/temporal/internal/testutil/testutil.go b/temporal/internal/testutil/testutil.go index 39444638..05dcd514 100644 --- a/temporal/internal/testutil/testutil.go +++ b/temporal/internal/testutil/testutil.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/TykTechnologies/storage/temporal/connector" + "github.com/TykTechnologies/storage/temporal/internal/driver/local" "github.com/TykTechnologies/storage/temporal/model" "github.com/stretchr/testify/assert" ) @@ -42,9 +43,12 @@ func TestConnectors(t *testing.T) []model.Connector { // redisv9 list redisConnector := newRedisConnector(t) - connectors = append(connectors, redisConnector) + // local non-blocking hashmap + localConnector := local.NewLocalConnector(local.NewLockFreeStore()) + connectors = append(connectors, localConnector) + return connectors } diff --git a/temporal/keyvalue/keyvalue.go b/temporal/keyvalue/keyvalue.go index 1870a016..99237ecb 100644 --- a/temporal/keyvalue/keyvalue.go +++ b/temporal/keyvalue/keyvalue.go @@ -1,6 +1,7 @@ package temporal import ( + "github.com/TykTechnologies/storage/temporal/internal/driver/local" "github.com/TykTechnologies/storage/temporal/internal/driver/redisv9" "github.com/TykTechnologies/storage/temporal/model" "github.com/TykTechnologies/storage/temporal/temperr" @@ -15,6 +16,8 @@ func NewKeyValue(conn model.Connector) (KeyValue, error) { switch conn.Type() { case model.RedisV9Type: return redisv9.NewRedisV9WithConnection(conn) + case model.LocalType: + return local.NewLocalStore(conn), nil default: return nil, temperr.InvalidHandlerType } diff --git a/temporal/keyvalue/keyvalue_test.go b/temporal/keyvalue/keyvalue_test.go index 0e2c7d3a..63259ebb 100644 --- a/temporal/keyvalue/keyvalue_test.go +++ b/temporal/keyvalue/keyvalue_test.go @@ -100,7 +100,7 @@ func TestKeyValue_Get(t *testing.T) { }{ { name: "non_existing_key", - key: "key1", + key: "key", expectedValue: "", expectedErr: temperr.KeyNotFound, }, @@ -238,6 +238,20 @@ func TestKeyValue_Increment(t *testing.T) { expectedValue: 1, expectedErr: nil, }, + { + name: "multi_increment_existing_key", + key: "counter", + setup: func(db KeyValue) { + for i := 0; i < 5; i++ { + _, err := db.Increment(context.Background(), "counter") + if err != nil { + t.Fatalf("Set() error = %v", err) + } + } + }, + expectedValue: 6, + expectedErr: nil, + }, { name: "empty_key", key: "", @@ -311,6 +325,25 @@ func TestKeyValue_Decrement(t *testing.T) { expectedValue: -1, expectedErr: nil, }, + { + name: "multi_decr_existing_key", + setup: func(db KeyValue) { + err := db.Set(context.Background(), "counter", "10", 0) + if err != nil { + t.Fatalf("Set() error = %v", err) + } + + for i := 0; i < 5; i++ { + _, err := db.Decrement(context.Background(), "counter") + if err != nil { + t.Fatalf("Set() error = %v", err) + } + } + }, + key: "counter", + expectedValue: 4, + expectedErr: nil, + }, { name: "empty_key", key: "", @@ -1220,6 +1253,10 @@ func TestKeyValue_GetKeysWithOpts(t *testing.T) { } for _, connector := range connectors { + if connector.Type() == "local" { + // local connector does not support SCAN + continue + } for _, tc := range tcs { t.Run(connector.Type()+"_"+tc.name, func(t *testing.T) { ctx := context.Background() diff --git a/temporal/list/list.go b/temporal/list/list.go index 36f1c6e8..a8440467 100644 --- a/temporal/list/list.go +++ b/temporal/list/list.go @@ -1,6 +1,7 @@ package list import ( + "github.com/TykTechnologies/storage/temporal/internal/driver/local" "github.com/TykTechnologies/storage/temporal/internal/driver/redisv9" "github.com/TykTechnologies/storage/temporal/model" "github.com/TykTechnologies/storage/temporal/temperr" @@ -14,6 +15,8 @@ func NewList(conn model.Connector) (List, error) { switch conn.Type() { case model.RedisV9Type: return redisv9.NewRedisV9WithConnection(conn) + case model.LocalType: + return local.NewLocalStore(conn), nil default: return nil, temperr.InvalidHandlerType } diff --git a/temporal/model/types.go b/temporal/model/types.go index 95697508..ee04390a 100644 --- a/temporal/model/types.go +++ b/temporal/model/types.go @@ -7,6 +7,7 @@ import ( const ( RedisV9Type = "redisv9" + LocalType = "local" ) type Connector interface { diff --git a/temporal/queue/queue.go b/temporal/queue/queue.go index 3535d926..6b0b7b5c 100644 --- a/temporal/queue/queue.go +++ b/temporal/queue/queue.go @@ -1,6 +1,7 @@ package queue import ( + "github.com/TykTechnologies/storage/temporal/internal/driver/local" "github.com/TykTechnologies/storage/temporal/internal/driver/redisv9" "github.com/TykTechnologies/storage/temporal/model" "github.com/TykTechnologies/storage/temporal/temperr" @@ -14,6 +15,8 @@ func NewQueue(conn model.Connector) (Queue, error) { switch conn.Type() { case model.RedisV9Type: return redisv9.NewRedisV9WithConnection(conn) + case model.LocalType: + return local.NewLocalStore(conn), nil default: return nil, temperr.InvalidHandlerType } diff --git a/temporal/set/set.go b/temporal/set/set.go index cd74734a..1d2c6b49 100644 --- a/temporal/set/set.go +++ b/temporal/set/set.go @@ -1,6 +1,7 @@ package set import ( + "github.com/TykTechnologies/storage/temporal/internal/driver/local" "github.com/TykTechnologies/storage/temporal/internal/driver/redisv9" "github.com/TykTechnologies/storage/temporal/model" "github.com/TykTechnologies/storage/temporal/temperr" @@ -14,6 +15,8 @@ func NewSet(conn model.Connector) (Set, error) { switch conn.Type() { case model.RedisV9Type: return redisv9.NewRedisV9WithConnection(conn) + case model.LocalType: + return local.NewLocalStore(conn), nil default: return nil, temperr.InvalidHandlerType } diff --git a/temporal/sortedset/sortedset.go b/temporal/sortedset/sortedset.go index ebf77584..726de473 100644 --- a/temporal/sortedset/sortedset.go +++ b/temporal/sortedset/sortedset.go @@ -1,6 +1,7 @@ package sortedset import ( + "github.com/TykTechnologies/storage/temporal/internal/driver/local" "github.com/TykTechnologies/storage/temporal/internal/driver/redisv9" "github.com/TykTechnologies/storage/temporal/model" "github.com/TykTechnologies/storage/temporal/temperr" @@ -14,6 +15,8 @@ func NewSortedSet(conn model.Connector) (SortedSet, error) { switch conn.Type() { case model.RedisV9Type: return redisv9.NewRedisV9WithConnection(conn) + case model.LocalType: + return local.NewLocalStore(conn), nil default: return nil, temperr.InvalidHandlerType }