diff --git a/lib/encoding/messagepack.go b/lib/encoding/messagepack.go index ed8eccf3..46771491 100644 --- a/lib/encoding/messagepack.go +++ b/lib/encoding/messagepack.go @@ -2,31 +2,189 @@ package encoding import ( "bytes" + "errors" "github.com/hashicorp/go-msgpack/codec" + "reflect" ) -// EncodeMessagePack is a function that encodes a given message using MessagePack serialization. -// It takes an interface{} parameter representing the message and returns the encoded byte slice and an error. -func EncodeMessagePack(msg interface{}) ([]byte, error) { - var b []byte - var mph codec.MsgpackHandle - h := &mph - enc := codec.NewEncoderBytes(&b, h) // Create a new encoder with the provided message and MessagePack handle +// MessagePackCodec struct, holds references to MessagePack handler and byte slice, +// along with Encoder and Decoder, and a typeMap for storing reflect.Type +type MessagePackCodec struct { + MsgPack *codec.MsgpackHandle + b *[]byte + enc *codec.Encoder + dec *codec.Decoder +} + +// MessagePackCodecEncoder struct derives from MessagePackCodec +// it manages IDs and counts of the encoded objects. +type MessagePackCodecEncoder struct { + MessagePackCodec // Embedded MessagePackCodec +} + +// MessagePackCodecDecoder struct, holds a reference to a MessagePackCodec instance. +type MessagePackCodecDecoder struct { + *MessagePackCodec +} - err := enc.Encode(msg) // Encode the message using the encoder +// NewMsgPackHandle is a helper function to create a new instance of MsgpackHandle +func NewMsgPackHandle() *codec.MsgpackHandle { + return &codec.MsgpackHandle{} +} + +// InitMessagePack function initializes MessagePackCodec struct and returns it. +func InitMessagePack() MessagePackCodec { + return MessagePackCodec{ + MsgPack: NewMsgPackHandle(), + } +} + +// NewMessagePackEncoder function creates new MessagePackCodecEncoder and initializes it. +func NewMessagePackEncoder() *MessagePackCodecEncoder { + msgPack := NewMsgPackHandle() + b := make([]byte, 0) + return &MessagePackCodecEncoder{ + MessagePackCodec: MessagePackCodec{ + MsgPack: msgPack, + b: &b, + enc: codec.NewEncoderBytes(&b, msgPack), + }, + } +} + +// NewMessagePackDecoder function takes in a byte slice, and returns a pointer to newly created +// and initialized MessagePackCodecDecoder +func NewMessagePackDecoder(b []byte) *MessagePackCodecDecoder { + msgPack := NewMsgPackHandle() + return &MessagePackCodecDecoder{ + MessagePackCodec: &MessagePackCodec{ + MsgPack: msgPack, + b: &b, + dec: codec.NewDecoderBytes(b, msgPack), + }, + } +} + +// Encode method for MessagePackCodec. It encodes the input value into a byte slice using MessagePack. +// Returns encoded byte slice or error. +func (m *MessagePackCodec) Encode(msg interface{}) ([]byte, error) { + var b []byte + err := codec.NewEncoderBytes(&b, m.MsgPack).Encode(msg) if err != nil { return nil, err } - return b, nil // Return the encoded byte slice + return b, nil +} + +// Encode is a method for MessagePackCodecEncoder. +// It takes in msg of type interface{} as input, that is to be encoded. +// Returns an error if encountered during encoding. +func (m *MessagePackCodecEncoder) Encode(msg interface{}) error { + return m.enc.Encode(msg) +} + +// Bytes is a method for MessagePackCodecEncoder. +// It returns a byte slice pointer b. +func (m *MessagePackCodecEncoder) Bytes() []byte { + return *m.b +} + +// Decode is a method on MessagePackCodecDecoder that decodes MessagePack data +// into the provided interface; returns an error if any decoding issues occur. +func (m *MessagePackCodecDecoder) Decode(msg interface{}) error { + if m.dec == nil { + return errors.New("decoder not initialized") + } + return m.dec.Decode(msg) +} + +// Decode on MessagePackCodec type, using a byte slice as input. +func (m *MessagePackCodec) Decode(in []byte, out interface{}) error { + // Create new decoder using the byte slice and MessagePack handle. + dec := codec.NewDecoderBytes(in, m.MsgPack) + + // Attempt to decode the byte slice into the desired output structure. + return dec.Decode(out) +} + +// AddExtension method allows for setting custom encoders/decoders for specific reflect.Types. +func (m *MessagePackCodec) AddExtension( + t reflect.Type, + id byte, + encoder func(reflect.Value) ([]byte, error), + decoder func(reflect.Value, []byte) error) error { + + return m.MsgPack.AddExt(t, id, encoder, decoder) +} + +// EncodeMessagePack function encodes a given object into MessagePack format. +func EncodeMessagePack(msg interface{}) ([]byte, error) { + // Directly initialize the byte slice and encoder. + b := make([]byte, 0) + enc := codec.NewEncoderBytes(&b, NewMsgPackHandle()) + + // Attempt to encode the message. + if err := enc.Encode(msg); err != nil { + return nil, err + } + + // Return the encoded byte slice. + return b, nil } -// DecodeMessagePack is a function that decodes a given byte slice using MessagePack deserialization. -// It takes an input byte slice and an interface{} representing the output structure for the deserialized message. -// It returns an error if the decoding process fails. +// DecodeMessagePack function decodes a byte slice of MessagePack data into a given object. func DecodeMessagePack(in []byte, out interface{}) error { - buf := bytes.NewBuffer(in) // Create a new buffer with the input byte slice - mph := codec.MsgpackHandle{} - dev := codec.NewDecoder(buf, &mph) // Create a new decoder with the buffer and MessagePack handle + dec := codec.NewDecoder(bytes.NewBuffer(in), NewMsgPackHandle()) + return dec.Decode(out) +} + +// EncodeString Functions for encoding and decoding strings to and from byte slices. +func EncodeString(s string) ([]byte, error) { + // Check if string length is within correct bounds. + if len(s) > 0x7F { + return nil, errors.New("invalid string length") + } + + // Create a byte slice of appropriate length. + b := make([]byte, len(s)+1) + b[0] = byte(len(s)) + + // Copy the string into the byte slice. + copy(b[1:], s) + + // Return the byte slice. + return b, nil +} + +/* +DecodeString converts an input byte slice into a string. +Arguments: + + b: Input byte slice to be decoded. + +Returns: +- Integer: Length of byte representation of the string. +- String: Decoded string. +- Error: 'invalid length' if slice is empty or insufficient length, else nil. + +The function reads the first byte as string length (l), creates a slice of length l and returns the formed string. +*/ +func DecodeString(b []byte) (int, string, error) { + // Check that byte slice is not empty. + if len(b) == 0 { + return 0, "", errors.New("invalid length") + } + + // Determine the length of the string. + l := int(b[0]) + if len(b) < (l + 1) { + return 0, "", errors.New("invalid length") + } + + // Create a byte slice of the appropriate length and copy the string into it. + s := make([]byte, l) + copy(s, b[1:l+1]) - return dev.Decode(out) // Decode the byte slice into the provided output structure + // Return the length of the string and the string itself. + return l + 1, string(s), nil } diff --git a/lib/encoding/messagepack_test.go b/lib/encoding/messagepack_test.go index 68b12adf..0f987fbf 100644 --- a/lib/encoding/messagepack_test.go +++ b/lib/encoding/messagepack_test.go @@ -1,8 +1,11 @@ package encoding import ( + "bytes" + "github.com/hashicorp/go-msgpack/codec" "github.com/hashicorp/raft" "github.com/stretchr/testify/assert" + "reflect" "testing" ) @@ -16,3 +19,195 @@ func TestEncodeMessagePack(t *testing.T) { assert.NoError(t, err) assert.Equal(t, decData, *data) } +func TestInitMessagePack(t *testing.T) { + msgPack := InitMessagePack() + + assert.NotNil(t, msgPack.MsgPack) +} + +func TestNewMessagePackEncoder(t *testing.T) { + encoder := NewMessagePackEncoder() + + assert.NotNil(t, encoder.enc) + assert.NotNil(t, encoder.b) +} + +func TestNewMessagePackDecoder(t *testing.T) { + exampleBytes := []byte("example") + decoder := NewMessagePackDecoder(exampleBytes) + + assert.NotNil(t, decoder.b) + assert.NotNil(t, decoder.dec) +} + +func TestEncode(t *testing.T) { + msg := "example message" + msgPack := InitMessagePack() + + encoded, err := msgPack.Encode(msg) + + assert.NotNil(t, encoded) + assert.Nil(t, err) +} + +func TestMsgPackDecoder_Decode(t *testing.T) { + msg := "example message" + msgPack := InitMessagePack() + encoded, _ := msgPack.Encode(msg) + decoder := NewMessagePackDecoder(encoded) + + var decoded string + err := decoder.Decode(&decoded) + + assert.Nil(t, err) + assert.Equal(t, msg, decoded) +} + +func TestMsgPackDecoder_Decode_ErrDecoderNotInitialized(t *testing.T) { + msgPack := InitMessagePack() + encoded, _ := msgPack.Encode("example message") + decoder := &MessagePackCodecDecoder{ + MessagePackCodec: &MessagePackCodec{ + MsgPack: NewMsgPackHandle(), + b: &encoded, + }, + } + + var decoded string + err := decoder.Decode(&decoded) + + assert.NotNil(t, err) + assert.Equal(t, "decoder not initialized", err.Error()) +} + +func TestEncodeString_ErrStringLength(t *testing.T) { + s := bytes.Repeat([]byte("a"), 0x80) // 128-byte long string + + encoded, err := EncodeString(string(s)) + + assert.Nil(t, encoded) + assert.NotNil(t, err) + assert.Equal(t, "invalid string length", err.Error()) +} + +func TestMessagePackCodec_Encode(t *testing.T) { + type TestStruct struct { + Field1 string + Field2 int + } + assert := assert.New(t) + + codec := &MessagePackCodec{ + MsgPack: NewMsgPackHandle(), + } + + t.Run("successful encoding", func(t *testing.T) { + testStruct := &TestStruct{ + Field1: "Test", + Field2: 1, + } + outStruct := &TestStruct{} + + b, err := codec.Encode(testStruct) + assert.Nil(err, "Error should be nil") + err = codec.Decode(b, outStruct) + assert.NoError(err) + assert.EqualValues(testStruct, outStruct) + }) + +} +func TestEncodeString(t *testing.T) { + tests := []struct { + input string + expected []byte + hasErr bool + }{ + {"hello", []byte{0x05, 'h', 'e', 'l', 'l', 'o'}, false}, + {"world", []byte{0x05, 'w', 'o', 'r', 'l', 'd'}, false}, + {string(make([]byte, 0x80)), nil, true}, + } + + for _, tt := range tests { + result, err := EncodeString(tt.input) + if (err != nil) != tt.hasErr { + t.Errorf("EncodeString(%q) error = %v, wantErr %v", tt.input, err, tt.hasErr) + continue + } + if !tt.hasErr && string(result) != string(tt.expected) { + t.Errorf("EncodeString(%q) = %q, want %q", tt.input, result, tt.expected) + } + } +} + +func TestDecodeString(t *testing.T) { + tests := []struct { + input []byte + expectedL int + expected string + hasErr bool + }{ + {[]byte{0x05, 'h', 'e', 'l', 'l', 'o'}, 6, "hello", false}, + {[]byte{0x05, 'w', 'o', 'r', 'l', 'd'}, 6, "world", false}, + {[]byte{0x05, 'w', 'o', 'r'}, 0, "", true}, + {[]byte{}, 0, "", true}, + } + + for _, tt := range tests { + length, result, err := DecodeString(tt.input) + if (err != nil) != tt.hasErr { + t.Errorf("DecodeString(%v) error = %v, wantErr %v", tt.input, err, tt.hasErr) + continue + } + if !tt.hasErr && (result != tt.expected || length != tt.expectedL) { + t.Errorf("DecodeString(%v) = %v,%q, want %v,%q", tt.input, length, result, tt.expectedL, tt.expected) + } + } +} +func TestMessagePackCodecEncoder_Encode(t *testing.T) { + var mh codec.MsgpackHandle + mh.MapType = reflect.TypeOf(map[int]int{}) + + encoder := NewMessagePackEncoder() + + err := encoder.Encode(map[int]int{1: 2}) + assert.NoError(t, err) + + err = encoder.Encode(map[int]int{3: 4}) + assert.NoError(t, err) +} + +func TestAddExtension(t *testing.T) { + + type CustomType struct { + Name string + } + + // global extention info. + const extensionID byte = 1 + + encoder := func(rv reflect.Value) ([]byte, error) { + ct := rv.Interface().(CustomType) + return []byte(ct.Name), nil + } + decoder := func(rv reflect.Value, b []byte) error { + rv.Set(reflect.ValueOf(CustomType{Name: string(b)})) + return nil + } + + m := NewMessagePackEncoder() + err := m.AddExtension(reflect.TypeOf(CustomType{}), extensionID, encoder, decoder) + if err != nil { + t.Fatalf("Failed adding extension: %v", err) + } + data := CustomType{Name: "test"} + dataVerify := CustomType{} + err = m.enc.Encode(&data) + assert.NoError(t, err) + assert.NotNil(t, m.Bytes()) + d := NewMessagePackDecoder(m.Bytes()) + err = d.AddExtension(reflect.TypeOf(CustomType{}), extensionID, encoder, decoder) + assert.NoError(t, err) + err = d.Decode(&dataVerify) + assert.NoError(t, err) + assert.EqualValues(t, dataVerify, data) +} diff --git a/structure/zset.go b/structure/zset.go new file mode 100644 index 00000000..04b03469 --- /dev/null +++ b/structure/zset.go @@ -0,0 +1,1235 @@ +package structure + +import ( + "errors" + "fmt" + "github.com/ByteStorage/FlyDB/config" + "github.com/ByteStorage/FlyDB/engine" + _const "github.com/ByteStorage/FlyDB/lib/const" + "github.com/ByteStorage/FlyDB/lib/encoding" + "math" + "math/rand" + "time" +) + +const ( + // SKIPLIST_MAX_LEVEL is better to be log(n) for the best performance. + SKIPLIST_MAX_LEVEL = 10 // + SKIPLIST_PROB = 0.25 // SkipList Probability +) + +/** +ZSet or Sorted Set structure is borrowed from Redis' implementation, the Redis implementation +utilizes a SkipList and a dictionary. +*/ + +// ZSetStructure is a structure for ZSet or SortedSet +type ZSetStructure struct { + db *engine.DB +} + +// FZSet represents a specific data structure in the database, which is key to handling sorted sets (ZSets). +// This struct facilitates interactions with data stored in the sorted set, allowing for both complex and simple operations. +// +// It contains three struct fields: +// +// - 'dict': A Go map with string keys and pointers to ZSetValue values. This map aims to provide quick access to +// individual values in the sorted set based on the provided key. +// +// - 'size': An integer value representing the current size (number of elements) in the FZSet struct. This information is efficiently +// kept track of whenever elements are added or removed from the set, so no separate computation is needed to retrieve this information. +// +// - 'skipList': A pointer towards a SkipList struct. SkipLists perform well under numerous operations, such as insertion, deletion, and searching. They are +// a crucial component in maintaining the sorted set in a practical manner. In this context, the SkipList is used to keep an ordered track of the elements +// in the FZSet struct. +type FZSet struct { + // dict field is a map where the key is a string and + // the value is a pointer to ZSetValue instances, + // codified with the tag "dict". + dict map[string]*ZSetValue `codec:"dict"` + + // size field represents the quantity of elements within + // the structure, codified with the tag "size". + size int `codec:"size"` + + // skipList field is a pointer to an object of type SkipList, + // codified with the tag "skip_list". + skipList *SkipList `codec:"skip_list"` +} + +// SkipList represents a skip list data structure, an ordered list with a hierarchical +// structure that allows for fast search and insertion of elements. +type SkipList struct { + // level represents the highest level of the skip list. + level int + + // head refers to the first node in the skip list. + head *SkipListNode + + // tail refers to the last node in the skip list. + tail *SkipListNode + + // size represents the total number of nodes in the skip list (excluding head and tail nodes). + size int +} + +// SkipListLevel is a structure encapsulating a single level in a skip list data structure. +// It contains two struct fields: +// - 'next': A pointer to the next SkipListNode in the current level. +// - 'span': An integer representing the span size of this SkipListLevel. The span is the number of nodes between the current node +// and the node to which the next pointer is pointing in the skip list. +type SkipListLevel struct { + next *SkipListNode + span int +} + +// SkipListNode represents a single node in a SkipList structure. +// It is built with three elements: +// - 'prev': This is a pointer to the previous node in the skip list. Together with the 'next' pointers in the SkipListNodeLevel, +// it forms a network of nodes, where traversal of the skip list is possible both forwards and backwards. +// - 'level': This is an array (slice) of pointers towards SkipListLevel structures. Each element corresponds to a level of the skip list, +// embedding the 'next' node at that same level, and the span between the current node and that 'next' node. +// - 'value': This is a pointer towards a single ZSetValue structure. It holds the actual payload of the node +// (namely the 'score', 'key', and 'value' properties used in the context of Redis Sorted Sets), as well as provides the basis for ordering of nodes in the skip list. +type SkipListNode struct { + // prev is a pointer to the previous node in the skip list. + prev *SkipListNode + + // level is a slice of pointers to SkipListLevel. + // Each level represents a forward pointer to the next node in the current list level. + level []*SkipListLevel + + // value is a pointer to the ZSetValue. + // This represents the value that this node holds. + value *ZSetValue +} + +// ZSetValue is a struct used in the SkipList data structure. In the context of Redis Sorted Set (ZSet) implementation, +// it represents a single node value in the skip list. A ZSetValue has three members: +// - 'score' which is an integer representing the score of the node. Nodes in a skip list are ordered by this score in ascending order. +// - 'member' which is a string defining the key of the node. For nodes with equal scores, order is determined with lexicographical comparison of keys. +// - 'value' which is an interface{}, meaning it can hold any data type. This represents the actual value of the node in the skip list. +type ZSetValue struct { + // Score is typically used for sorting purposes. Nodes with higher scores will be placed higher in the skip list. + score int + + // member represents the unique identifier for each node. + member string + + // value is the actual content/data that is being stored in the node. + value interface{} +} + +// randomLevel is a function that generates a probabilistic level for a node in a SkipList data structure. +// The goal is to diversify the level distribution and contribute to achieving an ideal skiplist performance. +// Function has no parameters. +// The process starts with two initial variables: +// - 'level' which starts from 1, +// - 'thresh' which is a product of the constant skiplist probability 'SKIPLIST_PROB' and bitwise mask: 0xFFF, taken to the nearest integer. +// +// In an infinite loop, a random 31-bit integer value is generated, bitwise-and is computed with 0xFFF and compared with 'thresh'. +// If the result is smaller, 'level' is incremented by one. Otherwise, the loop is exited. +// Finally, the function checks the calculated level against the maximum allowed skiplist level 'SKIPLIST_MAX_LEVEL'. +// If 'level' is greater, 'SKIPLIST_MAX_LEVEL' is returned, otherwise the calculated 'level' value is returned. +// The function returns an integer which will be the level of new node in skiplist. +func randomLevel() int { + // Initialize level to 1 + level := 1 + r := rand.New(rand.NewSource(time.Now().UnixNano())) + // Calculate the threshold for level. It's derived from the probability constant of the skip list. + thresh := int(math.Round(SKIPLIST_PROB * 0xFFF)) + + // While a randomly generated number is less than this threshold, increment the level. + for int(r.Int31()&0xFFF) < thresh { + level++ + } + + // Check if the level is more than the maximum allowed level for the skip list + // If it is, return the maximum level. Otherwise, return the generated level. + if level > SKIPLIST_MAX_LEVEL { + return SKIPLIST_MAX_LEVEL + } else { + return level + } +} + +// NewZSetStructure Returns a new ZSetStructure +func NewZSetStructure(options config.Options) (*ZSetStructure, error) { + db, err := engine.NewDB(options) + if err != nil { + return nil, err + } + return &ZSetStructure{db: db}, nil +} + +// newZSetNodes is a function that creates a new FZSet object and returns a pointer to it. +// It initializes the dictionary member dict of the newly created object to an empty map. +// The map is intended to map strings to pointers of ZSetValue objects. +// size member of the object is set to 0, indicating that the FZSet object is currently empty. +// The skipList member of the object is set to a new SkipList object created by calling `newSkipList()` function. +func newZSetNodes() *FZSet { + return &FZSet{ + dict: make(map[string]*ZSetValue), + size: 0, + skipList: newSkipList(), + } +} + +// newSkipList is a function that creates an instance of a SkipList struct object and returns a pointer to it. +// This involves initializing the level of the SkipList to 1 and creating a new SkipListNode object as the head of the list. +// The head node is constructed with a level set to SKIPLIST_MAX_LEVEL, key and value as empty string and value as nil respectively. +func newSkipList() *SkipList { + return &SkipList{ + level: 1, + head: newSkipListNode(SKIPLIST_MAX_LEVEL, 0, "", nil), + } +} + +// newSkipListNode is a function that takes integer as level, score and a string as key along with a value of any type. +// It returns a pointer to a SkipListNode. This function is responsible for creating a new SkipListNode with provided level, score, +// key, and value. After creating the node, it initializes every level of the node with an empty SkipListLevel object. +// In the context of a skip list data structure, this function serves as a helper function for creating new nodes to be inserted to the list. +func newSkipListNode(level int, score int, key string, value interface{}) *SkipListNode { + // Create a new SkipListNode with specified score, key, value and a slice of + // SkipListLevel with length equal to specified level + node := &SkipListNode{ + value: newSkipListNodeValue(score, key, value), + level: make([]*SkipListLevel, level), + } + + // Initialize each SkipListLevel in the level slice + for i := range node.level { + node.level[i] = new(SkipListLevel) + } + // Returning the pointer to the created node + return node +} + +// newSkipListNodeValue is a function that constructs and returns a new ZSetValue. +// It takes a score (int), a key (string), and a value (interface{}) as parameters. +// These parameters serve as the initial state of the ZSetValue upon its creation. +func newSkipListNodeValue(score int, member string, value interface{}) *ZSetValue { + // Create a new instance of a ZSetValue with the provided score, key, and value. + node := &ZSetValue{ + score: score, + member: member, + value: value, + } + + // Return the newly created ZSetValue. + return node +} + +// insert is a method of the SkipList type that is used to insert a new node into the skip list. It takes as arguments +// the score (int), key (string) and a value (interface{}), and returns a pointer to the ZSetValue struct. The method +// organizes nodes in the list based on the score in ascending order. If two nodes have the same score, they will be arranged +// based on the key value. The method also assigns span values to the levels in the skip list. +func (sl *SkipList) insert(score int, key string, value interface{}) *ZSetValue { + update := make([]*SkipListNode, SKIPLIST_MAX_LEVEL) + rank := make([]int, SKIPLIST_MAX_LEVEL) + node := sl.head + + // Go from highest level to lowest + for i := sl.level - 1; i >= 0; i-- { + // store rank that is crossed to reach the insert position + if sl.level-1 == i { + rank[i] = 0 + } else { + rank[i] = rank[i+1] + } + if node.level[i] != nil { + for node.level[i].next != nil && + (node.level[i].next.value.score < score || + (node.level[i].next.value.score == score && // score is the same but the key is different + node.level[i].next.value.member < key)) { + rank[i] += node.level[i].span + node = node.level[i].next + } + } + update[i] = node + } + level := randomLevel() + // add a new level + if level > sl.level { + for i := sl.level; i < level; i++ { + rank[i] = 0 + update[i] = sl.head + update[i].level[i].span = sl.size + } + sl.level = level + } + node = newSkipListNode(level, score, key, value) + + for i := 0; i < level; i++ { + node.level[i].next = update[i].level[i].next + update[i].level[i].next = node + // update span covered by update[i] as newNode is inserted here + node.level[i].span = update[i].level[i].span - (rank[0] - rank[i]) + update[i].level[i].span = (rank[0] - rank[i]) + 1 + } + // increment span for untouched levels + for i := level; i < sl.level; i++ { + update[i].level[i].span++ + } + // update info + if update[0] == sl.head { + node.prev = nil + } else { + node.prev = update[0] + } + if node.level[0].next != nil { + node.level[0].next.prev = node + } else { + sl.tail = node + } + sl.size++ + return node.value +} + +// SkipList is a data structure that allows fast search, insertion, and removal operations. +// Here we define a method delete on it. +// +// The delete method in the skip list will remove nodes that have a given score and key from the skip list. +// If no such nodes are found, the function does nothing. +// +// Parameters: +// +// score: the score of the node to delete. +// key: the key of the node to delete. +func (sl *SkipList) delete(score int, member string) { + + // update: an array of pointers to SkipListNodes; holds the nodes that will have their next pointers updated. + update := make([]*SkipListNode, SKIPLIST_MAX_LEVEL) + + // node: start from the head of our SkipList sl + node := sl.head + + // The code block of "for" loop populates the "update" variable with nodes which reference will change + // due to the removal of the target node. + for i := sl.level; i >= 0; i-- { + // This loop is traversing the SkipList horizontally until it finds a node with a score greater + // than or equal to our target score or if the scores are equal it also checks the member. + for node.level[i].next != nil && + (node.level[i].next.value.score < score || + (node.level[i].next.value.score == score && + node.level[i].next.value.member < member)) { + node = node.level[i].next + } + update[i] = node + } + + // After the traversal, we set the node to point to the possibly (to be) deleted node. + node = node.level[0].next + + // If the possibly deleted node is the target node (it has the same score and member), then remove it. + if node != nil && node.value.score == score && node.value.member == member { + sl.deleteNode(node, update) + } +} +func (sl *SkipList) getRange(start int, end int, reverse bool) (nv []ZSetValue) { + if end > sl.size { + end = sl.size - 1 + } + if start > end { + return + } + if end <= 0 { + return nil // todo unexpected behavior, we can set it to zero as well + } + + node := sl.head + if reverse { + node = sl.tail + if start > 0 { + node = sl.getNodeByRank(sl.size - start) + } + } else { + node = node.level[0].next + if start > 0 { + node = sl.getNodeByRank(start + 1) + } + } + + for i := start; i < end; i++ { + if reverse { + nv = append(nv, *node.value) + node = node.prev + } else { + nv = append(nv, *node.value) + node = node.level[0].next + } + } + return nv +} + +// deleteNode is a method linked to the SkipList struct that allows to remove nodes from the SkipList instance. +// It takes two parameters: a pointer to the node to be deleted, and a slice of pointers to SkipListNode which are required for node updates. +// deleteNode performs the deletion through a two-step process: +// - First, it loops over every level in the SkipList, updating level spans and next node pointers accordingly. +// - Then, it sets the pointers back to the previous node in the data structure and updates the tail and level of the whole list. +// Finally, it decreases the size of the list by one, as a node is being removed from it. +// It doesn't return any value and modifies the SkipList directly. +func (sl *SkipList) deleteNode(node *SkipListNode, updates []*SkipListNode) { + for i := 0; i < sl.level; i++ { + if updates[i].level[i].next == node { + updates[i].level[i].span += node.level[i].span - 1 + updates[i].level[i].next = node.level[i].next + } else { + updates[i].level[i].span-- + } + } + //update backwards + if node.level[0].next != nil { + node.level[0].next.prev = node.prev + } else { + sl.tail = node.prev + } + + for sl.level > 1 && sl.head.level[sl.level-1].next == nil { + sl.level-- + } + sl.size-- +} + +// getRank method receives a SkipList pointer and two parameters: an integer 'score' and a string 'key'. +// It then calculates the rank of an element in the SkipList. The rank is determined based on two conditions: +// - the score of the next node is less than the provided score +// - or, the score of the next node equal to the provided score and the key of the next node is less than or equal to the provided key. +// +// Parameters: +// sl: A pointer to the SkipList object. +// score: The score that we are comparing with the scores in the skiplist. +// key: The key that we are comparing with the keys in the skiplist. +// +// Return: +// Returns the rank of the element in the SkipList if it's found, otherwise returns 0. +func (sl *SkipList) getRank(score int, key string) int { + var rank int + h := sl.head // Start at the head node of the SkipList + + // For loop starts from the top level and goes down to the level 0 + for i := sl.level; i >= 0; i-- { + // While loop advances the 'h' pointer as long as the next node exists and the conditions are fulfilled + for h.level[i].next != nil && + (h.level[i].next.value.score < score || + (h.level[i].next.value.score == score && + h.level[i].next.value.member <= key)) { + + // Increase the rank by the span of the current level + rank += h.level[i].span + // Move to the next node + h = h.level[i].next + } + // If the key of the current node is equal to the provided key, return the rank + if h.value.member == key { + return rank + } + } + // If the element is not found in the SkipList, return 0 + return 0 +} + +// getNodeByRank is a method of the SkipList type that is used to retrieve a node based on its rank within the list. +// The method takes as argument an integer rank and returns a pointer to the SkipListNode at the specified rank, +// or nil if there is no such node. +// +// First, the method initializes a variable traversed to store the cumulative span of nodes traversed thus far in the search. +// It sets a helper variable h to the head of the SkipList, to begin the traversal. +// +// The method then enter a loop that iterates through the levels of the SkipList from the highest down to the base level. +// On each level, while the next node exists and the total span traversed plus the span of the next node doesn't exceed the target rank, +// the method moves to the next node and adds its span to the cumulative span traversed. +// +// If during the traversal the cumulative span equals the target rank, the method returns the getNodeByRank node. +// If the end of the SkipList is reached, or the target rank isn't found on any level, the method returns nil. +func (sl *SkipList) getNodeByRank(rank int) *SkipListNode { + // This variable is used to keep track of the number of nodes we have + // traversed while going through the levels of the SkipList. + var traversed int + + // Define a SkipListNode pointer h, initialized with sl.head + // At the start, this pointer is set to head node of the SkipList. + h := sl.head + + // The outer loop decrements levels from highest level to lowest. + for i := sl.level - 1; i >= 0; i-- { + + // The inner loop traverses the nodes at current level while the next node isn't null and we haven't traversed beyond the 'rank'. + // The traversed variable is also updated to include the span of the current level. + for h.level[i].next != nil && (traversed+h.level[i].span) <= rank { + traversed += h.level[i].span + h = h.level[i].next + } + + // If traversed equals 'rank', it means we've found the node at the rank we are looking for. + // So, return the node. + if traversed == rank { + return h + } + } + + // If the node at 'rank' wasn't found in the SkipList, return nil. + return nil +} + +// ZAdd adds a value with its given score and member to a sorted set (ZSet), associated with +// the provided key. It is a method on the ZSetStructure type. +// +// Parameters: +// +// key: a string that represents the key of the sorted set. +// score: an integer value that determines the order of the added element in the sorted set. +// member: a string used for identifying the added value within the sorted set. +// value: the actual value to be stored within the sorted set. +// +// If the key is an empty string, an error will be returned +func (zs *ZSetStructure) ZAdd(key string, score int, member string, value string) error { + return zs.ZAdds(key, []ZSetValue{{score: score, member: member, value: value}}...) +} + +// ZAdds adds a value with its given score and member to a sorted set (ZSet), associated with +// the provided key. It is a method on the ZSetStructure type. +// +// Parameters: +// +// values: ...ZSetValue multiple values of ZSetValue. +func (zs *ZSetStructure) ZAdds(key string, vals ...ZSetValue) error { + if err := checkKey(key); err != nil { + return err + } + zSet, err := zs.getOrCreateZSet(key) + if err != nil { + return fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + + for _, val := range vals { + // if values didn't change, do nothing + if zs.valuesDidntChange(zSet, val.score, val.member, val.value) { + continue + } + if err := zs.updateZSet(zSet, key, val.score, val.member, val.value); err != nil { + return fmt.Errorf("failed to set ZSet to DB with key '%v': %w", key, err) + } + } + + return zs.setZSetToDB(stringToBytesWithKey(key), zSet) +} + +// exists checks if a given member with a specific score exists in a ZSet. It +// also verifies if the provided key is valid. The function returns a boolean +// value indicating whether the member with the specified score exists in the +// ZSet or not. +// +// Parameters: +// +// key (string): Specifies the key of the ZSet. +// score (int): The score of the member to be checked. +// member (string): The specific member to check for in the ZSet. +// +// Returns: +// +// bool: A boolean value indicating whether a member with the specified score +// +// exists in the ZSet or not. Returns false if the ZSet does not exist or if +// the key is invalid. +func (zs *ZSetStructure) exists(key string, score int, member string) bool { + if err := checkKey(key); err != nil { + return false + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + + if err != nil { + return false + } + return zSet.exists(score, member) +} + +/* +ZRem is a method belonging to ZSetStructure that removes a member from a ZSet. + +Parameters: + - key (string): The key of the ZSet. + - member (string): The member to be removed. + +Returns: + - error: An error if the operation fails. + +The ZRem method checks for a non-empty key, retrieves the corresponding ZSet +from the database, removes the specified member, and then updates +the ZSet in the database. If any point of this operation fails, +the function will return the corresponding error. +*/ +func (zs *ZSetStructure) ZRem(key string, member string) error { + if err := checkKey(key); err != nil { + return err + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return err + } + if err = zSet.RemoveNode(member); err != nil { + return err + } + return zs.setZSetToDB(keyBytes, zSet) +} + +// ZRems method removes one or more specified members from the sorted set that's stored under the provided key. +// Params: +// - key string: the identifier for storing the sorted set in the database. +// - member ...string: a variadic parameter where each argument is a member string to remove. +// +// Returns: error +// +// The function will return an error if it fails at any point, if not it will return nil indicating a successful operation. +func (zs *ZSetStructure) ZRems(key string, member ...string) error { + if err := checkKey(key); err != nil { + return err + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + + if err != nil { + return fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + for _, s := range member { + if err = zSet.RemoveNode(s); err != nil { + return err + } + } + return zs.setZSetToDB(keyBytes, zSet) +} + +// ZScore method retrieves the score associated with the member in a sorted set stored at the key +func (zs *ZSetStructure) ZScore(key string, member string) (int, error) { + if err := checkKey(key); err != nil { + return 0, err + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return 0, err + } + // if the member in the sorted set is found, return the score associated with it + if v, ok := zSet.dict[member]; ok { + return v.score, nil + } + + // if the member doesn't exist in the set, return score of zero and an error + return 0, _const.ErrKeyNotFound +} + +/* +ZRank is a method belonging to the ZSetStructure type. This method retrieves the rank of an element within a sorted set identified by a key. The rank is an integer corresponding to the element's 0-based position in the sorted set when it is arranged in ascending order. + +Parameters: +key (string): The key that identifies the sorted set. +member (string): The element for which you want to find the rank. + +Returns: +int: An integer indicating the rank of the member in the set. + + Rank zero means the member is not found in the set. + +error: If an error occurs, an error object will be returned. + + Possible errors include: + - key is empty + - failure to get or create the ZSet from the DB + - the provided key does not exist in the DB + +Example: +rank, err := zs.ZRank("myKey", "memberName") + + if err != nil { + log.Fatal(err) + } + +fmt.Printf("The rank of '%s' in the set '%s' is %d\n", "memberName", "myKey", rank) +*/ +func (zs *ZSetStructure) ZRank(key string, member string) (int, error) { + if err := checkKey(key); err != nil { + return 0, err + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return 0, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + if v, ok := zSet.dict[member]; ok { + return zSet.skipList.getRank(v.score, member), nil + } + + // rank zero means no rank found + return 0, _const.ErrKeyNotFound +} + +// ZRevRank calculates the reverse rank of a member in a ZSet (Sorted Set) associated with a given key. +// ZSet exploits the Sorted Set data structure of Redis with O(log(N)) time complexity for Fetching the rank. +// +// Parameters: +// +// key: This is a string that serves as the key of a ZSet stored in the database. +// member: This is a string that represents a member of a ZSet whose rank needs to be obtained. +// +// Returns: +// +// int: The integer represents the reverse rank of the member in the ZSet. It returns 0 if the member is not found in the ZSet. +// On successful execution, it returns the difference of the ZSet size and the member's rank. +// error: The error which will be null if no errors occurred. If the key provided is empty, an ErrKeyIsEmpty error is returned. +// If there's a problem getting or creating the ZSet from the database, an error message is returned with the format +// "failed to get or create ZSet from DB with key '%v': %w", where '%v' is the key and '%w' shows the error detail. +// If the member is not found in the ZSet, it returns an ErrKeyNotFound error. +// +// Note: The reverse rank is calculated as 'size - rank', and the ranks start from 1. +func (zs *ZSetStructure) ZRevRank(key string, member string) (int, error) { + if err := checkKey(key); err != nil { + return 0, err + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return 0, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + if v, ok := zSet.dict[member]; ok { + rank := zSet.skipList.getRank(v.score, member) + return (zSet.size) - rank + 1, nil + } + + // rank zero means no rank found + return 0, _const.ErrKeyNotFound +} + +// ZRange retrieves a specific range of elements from a sorted set (ZSet) denoted by a specific key. +// It returns a slice of ZSetValue containing the elements within the specified range (inclusive), and a nil error when successful. +// +// The order of the returned elements is based on their rank in the set, not their score. +// +// Parameters: +// +// key: A string identifier representing the ZSet. The key shouldn't be an empty string. +// start: A zero-based integer representing the first index of the range. +// end: A zero-based integer representing the last index of the range. +// +// Returns: +// +// []ZSetValue: +// Slice of ZSetValue containing elements within the specified range. +// error: +// An error if it occurs during execution, such as: +// 1. The provided key string is empty. +// 2. An error occurs while fetching the ZSet from the database, i.e., the ZSet represented by the given key doesn't exist. +// In the case of an error, an empty slice and the actual error encountered will be returned. +// +// Note: +// On successful execution, ZRange returns the elements starting from 'start' index up to 'end' index inclusive. +// If the set doesn't exist or an error occurs during execution, ZRange returns an empty slice and the error. +// +// Example: +// Assume we have ZSet with the following elements: ["element1", "element2", "element3", "element4"] +// ZRange("someKey", 0, 2) will return ["element1", "element2", "element3"] and nil error. +// +// This method is part of the ZSetStructure type. +func (zs *ZSetStructure) ZRange(key string, start int, end int) ([]ZSetValue, error) { + if err := checkKey(key); err != nil { + return nil, err + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return nil, err + } + r := zSet.skipList.getRange(start, end, false) + + // rank zero means no rank found + return r, nil +} + +// ZCount traverses through the elements of the ZSetStructure based on the given key. +// The count of elements between the range of min and max scores is determined. +// +// The method takes a string as the key and two integers as min and max ranges. +// The range values are inclusive: [min, max]. If min is greater than max, an error is returned. +// The function ignores scores that fall out of the specified min and max range. +// +// It returns the count of elements within the range, and an error if any occurs during the process. +// +// For example, use as follows: +// count, err := zs.ZCount("exampleKey", 10, 50) +// This will count the number of elements that have the scores between 10 and 50 in the ZSetStructure associated with "exampleKey". +// +// Returns: +// 1. int: The total count of elements based on the score range. +// 2. error: Errors that occurred during execution, if any. +func (zs *ZSetStructure) ZCount(key string, min int, max int) (count int, err error) { + if err = checkKey(key); err != nil { + return 0, err + } + keyBytes := stringToBytesWithKey(key) + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return 0, err + } + if min > max { + return 0, ErrInvalidArgs + } + min, max, err = zs.adjustMinMax(zSet, min, max) + if err != nil { + return 0, err + } + x := zSet.skipList.head + // Node traversal loop. We keep moving to the next node at current level + // as long as the score of the next node's value is less than 'min'. + for i := zSet.skipList.level - 1; i >= 0; i-- { + for x.level[i].next != nil && x.level[i].next.value.score < min { + x = x.level[i].next + } + } + + x = x.level[0].next + // Score range check loop. We traverse nodes and increment 'count' + // as long as node value's score is in the range ['min', 'max'] + for x != nil { + if x.value.score > max { + break + } + count++ + x = x.level[0].next + } + return count, nil +} + +// ZRevRange retrieves a range of elements from a sorted set (ZSet) in descending order. +// Inputs: +// - key: Name of the ZSet +// - startRank: Initial rank of the desired range +// - endRank: Final rank of the desired range +// +// Output: +// - An array of ZSetValue, representing elements from the range [startRank, endRank] in descending order +// - Error if an issue occurs, such as when the key is empty or ZSet retrieval fails +// error +func (zs *ZSetStructure) ZRevRange(key string, startRank int, endRank int) ([]ZSetValue, error) { + if err := checkKey(key); err != nil { + return nil, err + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return nil, err + } + r := zSet.skipList.getRange(startRank, endRank, true) + + // rank zero means no rank found + return r, nil +} + +// The ZCard function returns the size of the dictionary of the sorted set stored at key in the database. +// It takes a string key as an argument. +func (zs *ZSetStructure) ZCard(key string) (int, error) { + if err := checkKey(key); err != nil { + return 0, err + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return 0, err + } + // get the size of the dictionary + return zSet.size, nil +} + +// ZIncrBy increases the score of an existing member in a sorted set stored at specified key by +// the increment `incBy` provided. If member does not exist, ErrKeyNotFound error is returned. +// If the key does not exist, it treats it as an empty sorted set and returns an error. +// +// The method accepts three parameters: +// `key`: a string type parameter that identifies the sorted set +// `member`: a string type parameter representing member in the sorted set +// `incBy`: an int type parameter provides the increment value for a member score +// +// The method throws error under following circumstances - +// if provided key is empty (ErrKeyIsEmpty error), +// if provided key or member is not present in the database (ErrKeyNotFound error), +// if it's unable to fetch or create ZSet from DB, +// if there's an issue with node insertion, +// if unable to set ZSet to DB post increment operation +func (zs *ZSetStructure) ZIncrBy(key string, member string, incBy int) error { + if err := checkKey(key); err != nil { + return err + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + + if v, ok := zSet.dict[member]; ok { + if err = zSet.InsertNode(v.score+incBy, member, v.value); err != nil { + return err + } + if err = zs.setZSetToDB(keyBytes, zSet); err != nil { + return err + } + return zs.setZSetToDB(keyBytes, zSet) + } + + return _const.ErrKeyNotFound +} + +// getOrCreateZSet attempts to retrieve a sorted set by a key, or creates a new one if it doesn't exist. +func (zs *ZSetStructure) getOrCreateZSet(key string) (*FZSet, error) { + keyBytes := stringToBytesWithKey(key) + zSet, err := zs.getZSetFromDB(keyBytes) + // if key is not in the DB, create it. + if errors.Is(err, _const.ErrKeyNotFound) { + return newZSetNodes(), nil + } + + return zSet, err +} + +// valuesDidntChange checks if the data of a specific member in a sorted set remained the same. +func (zs *ZSetStructure) valuesDidntChange(zSet *FZSet, score int, member string, value interface{}) bool { + if v, ok := zSet.dict[member]; ok { + return v.score == score && v.member == member && v.value == value + } + + return false +} + +// updateZSet updates or inserts a member in a sorted set and saves the change in storage. +func (zs *ZSetStructure) updateZSet(zSet *FZSet, key string, score int, member string, value interface{}) error { + return zSet.InsertNode(score, member, value) +} + +// InsertNode is a method on the FZSet structure. It inserts a new node +// or updates an existing node in the skip list and the dictionary. +// It takes three parameters: score (an integer), key (a string), +// and value (of any interface type). +// +// If key already exists in the dictionary and the score equals the existing +// score, it updates the value and score in the skip list and the dictionary. +// If the score is different, it only updates the value in the dictionary +// because the ranking doesn't change and there is no need for an update in the +// skip list. +// +// If the key doesn't exist in the dictionary, it adds the new key, value and score +// to the dictionary, increments the size of the dictionary by 1, and also adds +// the node to the skip list. +func (fzs *FZSet) InsertNode(score int, member string, value interface{}) error { + // Instantiate dictionary if it's not already + if fzs.dict == nil { + fzs.dict = make(map[string]*ZSetValue) + } + if fzs.skipList == nil { + fzs.skipList = newSkipList() + } + + // Check if key exists in dictionary + if v, ok := fzs.dict[member]; ok { + if v.score != score { + // Update value and score as the score remains the same + fzs.skipList.delete(score, member) + fzs.dict[member] = fzs.skipList.insert(score, member, value) + } else { + // Ranking isn't altered, only update value + v.value = value + } + } else { // Key doesn't exist, create new key + fzs.dict[member] = fzs.skipList.insert(score, member, value) + fzs.size++ // Increase size count by 1 + // Node is also added to the skip list + } + + // Returns nil as no specific error condition is checked in this function + return nil +} +func (zs *ZSetStructure) adjustMinMax(zSet *FZSet, min int, max int) (adjustedMin int, adjustedMax int, err error) { + if min > max { + return min, max, ErrInvalidArgs + } + minScore, maxScore := zSet.getMinMaxScore() + return zSet.max(min, minScore), zSet.min(max, maxScore), nil +} +func (fzs *FZSet) getMinMaxScore() (minScore int, maxScore int) { + if fzs == nil || fzs.skipList == nil || fzs.skipList.head == nil || len(fzs.skipList.head.level) < 1 || fzs.skipList.head.level[0].next == nil || fzs.skipList.tail == nil { + return 0, 0 + } + + if fzs.skipList.head.level[0].next.value == nil || fzs.skipList.tail.value == nil { + return 0, 0 + } + return fzs.skipList.head.level[0].next.value.score, + fzs.skipList.tail.value.score +} +func (fzs *FZSet) min(a, b int) int { + if a < b { + return a + } + return b +} + +func (fzs *FZSet) max(a, b int) int { + if a > b { + return a + } + return b +} + +// RemoveNode is a method for FZSet structure. +// This method aims to delete a node from both +// the dictionary (dict) and the skip list (skipList). +// +// The method receives one parameter: +// - member: a string that represents the key of the node +// to be removed from the FZSet structure. +// +// The method follows these steps: +// 1. Check if a node with key 'member' exists in the dictionary. +// If not, or if the dictionary itself is nil, it returns an error +// (_const.ErrKeyNotFound) indicating that the node cannot be found. +// 2. If the node exists, it proceeds to remove the node from both the +// skip list and dictionary. +// 3. After the successful removal of the node, it returns nil indicating +// the success of the operation. +// +// The RemoveNode's primary purpose is to provide a way to securely and +// efficiently remove a node from the FZSet structure. +func (fzs *FZSet) RemoveNode(member string) error { + // Check for existence of key in dictionary + v, ok := fzs.dict[member] + if !ok || fzs.dict == nil { + return _const.ErrKeyNotFound + } + + // Delete Node from the skip list and dictionary + fzs.skipList.delete(v.score, member) + delete(fzs.dict, member) + fzs.size-- + + return nil +} + +func (fzs *FZSet) exists(score int, member string) bool { + v, ok := fzs.dict[member] + + return ok && v.score == score +} + +// Bytes encodes the FZSet instance into bytes using MessagePack +// binary serialization format. The encoded bytes can be used for +// storage or transmission. If the encoding operation fails, an +// error is returned. +func (fzs *FZSet) Bytes() ([]byte, error) { + var msgPack = encoding.NewMessagePackEncoder() + if encodingError := msgPack.Encode(fzs); encodingError != nil { + return nil, encodingError + } + return msgPack.Bytes(), nil +} + +// FromBytes decodes the input byte slice into the FZSet object using MessagePack. +// Returns an error if decoding fails, otherwise nil. +func (fzs *FZSet) FromBytes(b []byte) error { + return encoding.NewMessagePackDecoder(b).Decode(fzs) +} + +// getZSetFromDB fetches and deserializes FZSet from the database. +// +// Returns a pointer to the FZSet and error, if any. +// If the key doesn't exist, both the pointer and the error will be nil. +// In case of deserialization errors, returns nil and the error. +func (zs *ZSetStructure) getZSetFromDB(key []byte) (*FZSet, error) { + dbData, err := zs.db.Get(key) + + // If key is not found, return nil for both; otherwise return the error. + if err != nil { + + return nil, err + } + dec := encoding.NewMessagePackDecoder(dbData) + // Deserialize the data. + var zSetValue FZSet + if err = dec.Decode(&zSetValue); err != nil { + return nil, err + } + + // return a pointer to the deserialized FZSet, nil for the error + return &zSetValue, nil +} + +// checkKey function that accepts a string parameter key +// and returns error if key is empty. +// +// # It returns nil otherwise +// +// Parameters: +// +// key : A string that is checked if empty +// +// Returns: +// +// error : _const.ErrKeyIsEmpty if key is empty, nil otherwise +func checkKey(key string) error { + if len(key) == 0 { + return _const.ErrKeyIsEmpty + } + return nil +} + +// setZSetToDB writes a FZSet object to the database. +// +// parameters: +// key: This is a byte slice that is used as a key in the database. +// zSetValue: This is a pointer to a FZSet object that needs to be stored in the database. +// +// The function serializes the FZSet object into MessagePack format. If an error occurs +// either during serialization or when writing to the database, that specific error is returned. +// If the process is successful, it returns nil. +func (zs *ZSetStructure) setZSetToDB(key []byte, zSetValue *FZSet) error { + val := encoding.NewMessagePackEncoder() + err := val.Encode(zSetValue) + if err != nil { + return err + } + return zs.db.Put(key, val.Bytes()) +} + +// UnmarshalBinary de-serializes the given byte slice into FZSet instance +// it uses MessagePack format for de-serialization +// Returns an error if the decoding of size or insertion of node fails. +// +// Parameters: +// data : a slice of bytes to be decoded +// +// Returns: +// An error that will be nil if the function succeeds. +func (fzs *FZSet) UnmarshalBinary(data []byte) (err error) { + // NewMessagePackDecoder creates a new MessagePack decoder with the provided data + dec := encoding.NewMessagePackDecoder(data) + + var size int + // Decode the size of the data structure + if err = dec.Decode(&size); err != nil { + return err // error handling if something goes wrong with decoding + } + + // Iterate through each node in the data structure + for i := 0; i < size; i++ { + // Create an empty instance of ZSetValue for each node + slValue := ZSetValue{} + + // Decode each node onto the empty ZSetValue instance + if err = dec.Decode(&slValue); err != nil { + return err // error handling if something goes wrong with decoding + } + + // Insert the decoded node into the FZSet instance + if err = fzs.InsertNode(slValue.score, slValue.member, slValue.value); err != nil { + return err + } + } + return // if all nodes are correctly decoded and inserted, return with nil error +} + +// MarshalBinary serializes the FZSet instance into a byte slice. +// It uses MessagePack format for serialization +// Returns the serialized byte slice and an error if the encoding fails. +func (fzs *FZSet) MarshalBinary() (_ []byte, err error) { + + // Initializing the MessagePackEncoder + enc := encoding.NewMessagePackEncoder() + + // Encoding the size attribute of d (i.e., d.size). The operation could fail, thus we check for an error. + // An error, if occurred, will be returned immediately, hence the flow of execution stops here. + err = enc.Encode(fzs.size) + if err != nil { + return nil, err + } + + // This is the start of a loop going over all the nodes in d's skip list from the tail of the + // list to the head. + // The tail and head pointers refer to the last and first element of the list, respectively, + // and are maintained for efficient traversing of the list. + // we do that to get the elements in reverse order from biggest to the smallest for the best + // insertion efficiency as it makes the insertion O(1), because each new element to be inserted is + // the smallest yet. + x := fzs.skipList.tail + // as long as there are elements in the SkipList continue + for x != nil { + // Encoding the value of the current node in the skip list + // Again, if an error occurs it gets immediately returned, thus breaking the loop. + err = enc.Encode(x.value) + if err != nil { + return nil, err + } + + // Move to the previous node in the skip list. + x = x.prev + } + + // After the traversal of the skip list, the encoder should now hold the serialized representation of the + // FZSet. Now, we return the bytes from the encoder along with any error that might have occurred + // during the encoding (should be nil if everything went fine). + return enc.Bytes(), err +} + +// UnmarshalBinary de-serializes the given byte slice into ZSetValue instance +// It uses the MessagePack format for de-serialization +// Returns an error if the decoding of Key, Score, or Value fails. +func (p *ZSetValue) UnmarshalBinary(data []byte) (err error) { + dec := encoding.NewMessagePackDecoder(data) + if err = dec.Decode(&p.member); err != nil { + return + } + if err = dec.Decode(&p.score); err != nil { + return err + } + if err = dec.Decode(&p.value); err != nil { + return + } + return +} + +// MarshalBinary uses MessagePack as the encoding format to serialize +// the ZSetValue object into a byte array. +func (d *ZSetValue) MarshalBinary() (_ []byte, err error) { + + // The NewMessagePackEncoder function is called to create a new + // MessagePack encoder. + enc := encoding.NewMessagePackEncoder() + + // Then, we try to encode the 'key' field of the ZSetValue + // If an error occurs, it is returned immediately along with the + // currently encoded byte slice. + if err = enc.Encode(d.member); err != nil { + return enc.Bytes(), err + } + + // We do the same for the 'score' field. + if err = enc.Encode(d.score); err != nil { + return enc.Bytes(), err + } + + // Lastly, the 'value' field is encoded in the same way. + if err = enc.Encode(d.value); err != nil { + return enc.Bytes(), err + } + + // If everything goes well and we're done encoding, we return the + // final byte slice which represents the encoded ZSetValue + // and a nil error. + return enc.Bytes(), err +} diff --git a/structure/zset_test.go b/structure/zset_test.go new file mode 100644 index 00000000..8cc0a83f --- /dev/null +++ b/structure/zset_test.go @@ -0,0 +1,1039 @@ +package structure + +import ( + "github.com/ByteStorage/FlyDB/config" + _const "github.com/ByteStorage/FlyDB/lib/const" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "os" + "reflect" + "testing" +) + +func initZSetDB() (*ZSetStructure, *config.Options) { + opts := config.DefaultOptions + dir, _ := os.MkdirTemp("", "TestZSetStructure") + opts.DirPath = dir + zs, _ := NewZSetStructure(opts) + return zs, &opts +} + +func TestSortedSet(t *testing.T) { + type test struct { + name string + input map[string]int + want *FZSet + expectError bool + } + + zs := newZSetNodes() + err := zs.InsertNode(3, "banana", "hello") + err = zs.InsertNode(1, "apple", "hello") + err = zs.InsertNode(2, "pear", "hello") + err = zs.InsertNode(44, "orange", "hello") + err = zs.InsertNode(9, "strawberry", "delish") + err = zs.InsertNode(15, "dragon-fruit", "nonDelish") + b, err := zs.Bytes() + + fromBytes := newZSetNodes() + err = fromBytes.FromBytes(b) + assert.NoError(t, err) + assert.NotNil(t, fromBytes.skipList) + assert.Equal(t, fromBytes.size, zs.size) + tests := []test{ + { + name: "empty", + input: map[string]int{}, + want: &FZSet{}, + expectError: false, + }, + { + name: "three fruits", + input: map[string]int{"banana": 3, "apple": 2, "pear": 4, "peach": 40}, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.ElementsMatch(t, tt.want, nil) + }) + } + +} + +func TestSortedSet_Bytes(t *testing.T) { + +} + +func TestZRem(t *testing.T) { + + type testCase struct { + name string + key string + setup func(z *ZSetStructure) + members []string + want []string + dontWant []string + err error + } + + testCases := []testCase{ + { + name: "key empty", + setup: func(z *ZSetStructure) { + _ = z.ZAdds("key1", []ZSetValue{{}}...) + }, + members: []string{""}, + err: _const.ErrKeyIsEmpty, + }, + { + name: "key not found", + setup: func(z *ZSetStructure) { + _ = z.ZAdds("key1", []ZSetValue{{}}...) + }, + key: "notfound", + members: []string{""}, + err: _const.ErrKeyNotFound, + }, + { + name: "member not found", + setup: func(z *ZSetStructure) { + _ = z.ZAdds("key1", []ZSetValue{{}}...) + }, + key: "key1", + members: []string{"notfound"}, + err: _const.ErrKeyNotFound, + }, + { + name: "member empty", + setup: func(z *ZSetStructure) { + _ = z.ZAdds("key1", []ZSetValue{{score: 1, member: "mem1", value: ""}}...) + }, + key: "key1", + members: []string{""}, + err: _const.ErrKeyNotFound, + }, + { + name: "remove half members", + setup: func(z *ZSetStructure) { + _ = z.ZAdds("key1", []ZSetValue{ + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}}...) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockZSetStructure, _ := initZSetDB() + tc.setup(mockZSetStructure) + + for _, m := range tc.members { + err = mockZSetStructure.ZRem(tc.key, m) + assert.EqualError(t, err, tc.err.Error()) + } + if tc.err == nil { + // check if member added + } + + }) + } +} +func TestZRems(t *testing.T) { + mockZSetStructure, _ := initZSetDB() + + // 1. Test for Key is Empty + err := mockZSetStructure.ZRems("", "member") + require.Error(t, err) + require.Equal(t, _const.ErrKeyIsEmpty, err) + type testCase struct { + key string + input []ZSetValue + rems []string + want []ZSetValue + err error + } + + testCases := []testCase{ + {"key", + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + []string{ + "mem0", + "mem1", + "mem6", + }, []ZSetValue{ + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + }, nil}, + {"", + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + }, + []string{ + "mem0", + "mem1", + "mem2", + "mem3", + "mem4", + "mem5", + "mem6", + }, + []ZSetValue{}, + _const.ErrKeyIsEmpty}, + { + "Key1", + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + }, + []string{ + "mem0", + "mem1", + "mem2", + "mem3", + "mem4", + "mem5", + "mem6", + }, []ZSetValue{}, _const.ErrKeyNotFound}, + } + + for _, tc := range testCases { + t.Run(tc.key, func(t *testing.T) { + _ = mockZSetStructure.ZAdds(tc.key, tc.input...) + + //remove all the elements + err = mockZSetStructure.ZRems(tc.key, tc.rems...) + assert.Equal(t, tc.err, err) + //validate + for _, value := range tc.want { + te := mockZSetStructure.exists(tc.key, value.score, value.member) + assert.True(t, te) + } + }) + } +} +func TestZAdd(t *testing.T) { + zs, _ := initZSetDB() + type testCase struct { + key string + score int + member string + value string + want ZSetValue + err error + } + + testCases := []testCase{ + { + "key", + 10, + "member", + "value", + ZSetValue{member: "member"}, + nil, + }, + { + "", + 10, + "member", + "value", + ZSetValue{member: ""}, + _const.ErrKeyIsEmpty, + }, + } + + for _, tc := range testCases { + t.Run(tc.key, func(t *testing.T) { + err := zs.ZAdd(tc.key, tc.score, tc.member, tc.value) + assert.Equal(t, tc.err, err) + if tc.err == nil { + // check if member added + assert.True(t, zs.exists(tc.key, tc.score, tc.member)) + err = zs.ZRem(tc.key, tc.member) + assert.NoError(t, err) + // should be removed successfully + assert.False(t, zs.exists(tc.key, tc.score, tc.member)) + } + // Adjust according to your error handling + + }) + } +} +func TestZAdds(t *testing.T) { + zs, _ := initZSetDB() + + // 1. Test for Key is Empty + err := zs.ZAdds("", []ZSetValue{}...) + require.Error(t, err) + require.Equal(t, _const.ErrKeyIsEmpty, err) + type testCase struct { + key string + input []ZSetValue + want []ZSetValue + err error + } + + testCases := []testCase{ + {"key", + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + nil}, + {"", + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + }, + []ZSetValue{}, + _const.ErrKeyIsEmpty, + }, + } + + for _, tc := range testCases { + t.Run(tc.key, func(t *testing.T) { + err = zs.ZAdds(tc.key, tc.input...) + assert.Equal(t, tc.err, err) + //validate + for _, value := range tc.want { + te := zs.exists(tc.key, value.score, value.member) + assert.True(t, te) + } + }) + } +} +func TestZIncrBy(t *testing.T) { + zs, _ := initZSetDB() + err := zs.ZIncrBy("", "non-existingMember", 5) + if err == nil { + t.Error("Expected error for empty key not returned") + } + + err = zs.ZIncrBy("key", "non-existingMember", 5) + if !assert.ErrorIs(t, err, _const.ErrKeyNotFound) { + t.Error("Expected ErrKeyNotFound for non-existing member not returned") + } + err = zs.ZAdd("key", 1, "existingMember", "") + assert.NoError(t, err) + err = zs.ZIncrBy("key", "existingMember", 5) + assert.NoError(t, err) + + Zset, err := zs.getZSetFromDB(stringToBytesWithKey("key")) + assert.Equal(t, 6, Zset.dict["existingMember"].score) +} +func TestZRank(t *testing.T) { + zs, _ := initZSetDB() + + // Assume that ZAdd adds a member to a set and assigns the member a score. + // Here the score does not matter + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + // Test when member is present in the set + rank, err := zs.ZRank("myKey", "member1") + assert.NoError(t, err) // no error should occur + assert.Equal(t, 1, rank) // as we inserted 'member1' first, its rank should be 1 + + // Test when member is not present in the set + rank, err = zs.ZRank("myKey", "unavailableMember") + assert.Error(t, err) // an error should occur + assert.Equal(t, 0, rank) // as 'unavailableMember' is not part of set, rank should be 0 + + // Test with an empty key + rank, err = zs.ZRank("", "member") + assert.Error(t, err) // an error should occur + assert.Equal(t, 0, rank) // rank should be 0 for invalid key} + + // Test member2 which should be 2nd + rank, err = zs.ZRank("myKey", "member2") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 2, rank) // rank should be 2 for key `member2` + + // Test member3 which should be 3rd + rank, err = zs.ZRank("myKey", "member3") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 3, rank) + + // remove member2 and test `member3` which should become 2 + err = zs.ZRem("myKey", "member2") + assert.NoError(t, err) // there should be no errors + rank, err = zs.ZRank("myKey", "member3") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 2, rank) // now `member3` should become 2nd +} +func TestZRevRank(t *testing.T) { + zs, _ := initZSetDB() + + // Assume that ZAdd adds a member to a set and assigns the member a score. + // Here the score does not matter + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + // Test when member is present in the set + rank, err := zs.ZRevRank("myKey", "member3") + assert.NoError(t, err) // no error should occur + assert.Equal(t, 2, rank) // as we inserted 'member1' first, its rank should be 1 + + // Test when member is not present in the set + rank, err = zs.ZRevRank("myKey", "unavailableMember") + assert.Error(t, err) // an error should occur + assert.Equal(t, 0, rank) // as 'unavailableMember' is not part of set, rank should be 0 + + // Test with an empty key + rank, err = zs.ZRevRank("", "member") + assert.Error(t, err) // an error should occur + assert.Equal(t, 0, rank) // rank should be 0 for invalid key} + + // Test member2 which should be 2nd + rank, err = zs.ZRevRank("myKey", "member1") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 4, rank) // rank should be 2 for key `member2` + + // Test member3 which should be 3rd + rank, err = zs.ZRevRank("myKey", "member4") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 1, rank) + + // remove member2 and test `member3` which should become 2 + err = zs.ZRem("myKey", "member2") + assert.NoError(t, err) // there should be no errors + rank, err = zs.ZRevRank("myKey", "member3") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 2, rank) // now `member3` should become 2nd +} +func TestZRevRange(t *testing.T) { + zs, _ := initZSetDB() + + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 5, "member5", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 6, "member6", "") + assert.NoError(t, err) + + var n []uint8 + tests := []struct { + key string + start int + end int + want []ZSetValue + wantErr error + }{ + {"myKey", 0, 3, []ZSetValue{ + {6, "member6", n}, + {5, "member5", n}, + {4, "member4", n}, + }, nil}, + {"", 0, 2, nil, _const.ErrKeyIsEmpty}, + {"fail", 0, 2, nil, _const.ErrKeyNotFound}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + got, err := zs.ZRevRange(tt.key, tt.start, tt.end) + + if !reflect.DeepEqual(err, tt.wantErr) { + t.Errorf("ZRange() error = %v, wantErr %v", err, tt.wantErr) + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ZRange() = %v, want %v", got, tt.want) + } + }) + } +} +func TestZRange(t *testing.T) { + zs, _ := initZSetDB() + + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 5, "member5", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 6, "member6", "") + assert.NoError(t, err) + var n []uint8 + tests := []struct { + key string + start int + end int + want []ZSetValue + wantErr error + }{ + {"myKey", 0, 3, []ZSetValue{ + {1, "member1", n}, + {2, "member2", n}, + {3, "member3", n}, + }, nil}, + {"", 0, 2, nil, _const.ErrKeyIsEmpty}, + {"fail", 0, 2, nil, _const.ErrKeyNotFound}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + got, err := zs.ZRange(tt.key, tt.start, tt.end) + + if !reflect.DeepEqual(err, tt.wantErr) { + t.Errorf("ZRange() error = %v, wantErr %v", err, tt.wantErr) + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ZRange() = %v, want %v", got, tt.want) + } + }) + } +} +func TestZCard(t *testing.T) { + zs, _ := initZSetDB() + + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 5, "member5", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 6, "member6", "") + assert.NoError(t, err) + tests := []struct { + name string + key string + want int + wantErr error + }{ + {"Empty Key", "", 0, _const.ErrKeyIsEmpty}, + {"Non-Existent Key", "nonExist", 0, _const.ErrKeyNotFound}, + {"Existing Key", "myKey", 6, nil}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + got, err := zs.ZCard(tt.key) + + if tt.want != got { + t.Fatalf("expected %d, got %d", tt.want, got) + } + + if tt.wantErr != nil { + if err == nil || tt.wantErr.Error() != err.Error() { + t.Fatalf("expected error '%v', got '%v'", tt.wantErr, err) + } + + } else if err != nil { + t.Fatalf("expected no error, got error '%v'", err) + } + }) + } +} +func TestZScore(t *testing.T) { + zs, _ := initZSetDB() + + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 5, "member5", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 6, "member6", "") + assert.NoError(t, err) + tests := []struct { + expectError error + expectedScore int + key string + member string + }{ + {_const.ErrKeyIsEmpty, 0, "", "member1"}, + {_const.ErrKeyNotFound, 0, "key1", "foo"}, + {nil, 1, "myKey", "member1"}, + {nil, 2, "myKey", "member2"}, + } + + for _, test := range tests { + t.Run(test.key, func(t *testing.T) { + score, err := zs.ZScore(test.key, test.member) + assert.Equal(t, test.expectError, err) + assert.Equal(t, test.expectedScore, score) + }) + } +} +func TestNewSkipList(t *testing.T) { + s := newSkipList() + + assertions := assert.New(t) + assertions.Equal(1, s.level) + assertions.Nil(s.head.prev) + assertions.Equal(0, s.head.value.score) + assertions.Equal("", s.head.value.member) +} + +func TestNewSkipListNode(t *testing.T) { + score := 10 + key := "test_key" + value := "test_value" + level := 5 + + node := newSkipListNode(level, score, key, value) + + // Validate node's value + if node.value.score != score || node.value.member != key || node.value.value != value { + t.Errorf("Unexpected value in node, got: %v, want: {score: %d, key: %s, val: %s}.\n", node.value, score, key, value) + } + + // Validate node's level slice length + if len(node.level) != level { + t.Errorf("Unexpected length of node's level slice, got: %d, want: %d.\n", len(node.level), level) + } + + // Validate each SkipListLevel in the level slice + for _, l := range node.level { + if l.next != nil || l.span != 0 { + t.Errorf("Unexpected SkipListLevel, got: %v, want: {forward: nil, span: 0}.\n", l) + } + } +} + +func TestSkipList_delete(t *testing.T) { + type deleteTest struct { + name string + score int + member string + targetList []testZSetNodeValue + inputList []testZSetNodeValue + } + + vals := []testZSetNodeValue{ + {score: 1, member: "mem1", value: nil}, + {score: 2, member: "mem2", value: nil}, + {score: 3, member: "mem3", value: nil}, + {score: 4, member: "mem4", value: nil}, + {score: 5, member: "mem5", value: nil}, + } + + // Omitted: Add some nodes into sl... + + tests := []deleteTest{ + { + name: "Delete Test 1", + score: 15, + member: "member1", + targetList: []testZSetNodeValue{{score: 3, member: "mem3"}}, // result of adding nodes into sl + inputList: vals, + }, + // Add more test cases here... + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + head := newZSetNodes() + populateSkipListFromSlice(head, test.inputList) + + for _, value := range test.targetList { + // check if the insertion has been performed + assert.True(t, head.exists(value.score, value.member)) + // delete the target members + assert.NoError(t, head.RemoveNode(value.member)) + // check to see if the deletion has been correctly performed + assert.False(t, head.exists(value.score, value.member)) + + } + }) + } +} + +type testZSetNodeValue struct { + score int + member string + value interface{} +} + +func populateSkipListFromSlice(nodes *FZSet, zSetNodeValues []testZSetNodeValue) { + // Iterate over the zsetNodes array + for _, zSetNode := range zSetNodeValues { + _ = nodes.InsertNode(zSetNode.score, zSetNode.member, zSetNode.value) + } +} +func TestRandomLevel(t *testing.T) { + for i := 0; i < 1000; i++ { + level := randomLevel() + if level < 1 || level > SKIPLIST_MAX_LEVEL { + t.Errorf("Generated level out of range: %v", level) + } + } +} +func TestZSetNodes_InsertNode(t *testing.T) { + pq := &FZSet{} + + // Case 1: Insert new node + err := pq.InsertNode(1, "test", "value") + if err != nil { + t.Error("Failed when inserting a new node") + } + + if _, ok := pq.dict["test"]; !ok { + t.Error("Insert node failed, expected key to exist in dictionary") + } + + // Case 2: Update existing node with same score + err = pq.InsertNode(1, "test", "newvalue") + if err != nil { + t.Error("Failed when updating a score with same value") + } + + if v, ok := pq.dict["test"]; !ok || v.value != "newvalue" { + t.Error("Update node failed, expected value to be updated") + } + + // Case 3: Insert node with existing key but different score + err = pq.InsertNode(2, "test", "newvalue") + if err != nil { + t.Error("Failed when updating a score with a new value") + } + + if v, ok := pq.dict["test"]; !ok || v.score != 2 { + t.Error("Update node failed, expected score to be updated") + } +} +func TestZCount(t *testing.T) { + zs, _ := initZSetDB() + + tests := []struct { + key string + input []testZSetNodeValue + min int + max int + want int + err error + }{ + { + "test1", + []testZSetNodeValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + 1, 5, 5, nil, + }, + { + "test2", + []testZSetNodeValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + 0, 5, 6, nil, + }, + { + "test3", + []testZSetNodeValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + 1, 3, 3, nil, + }, + { + "test4", + []testZSetNodeValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + 2, 2, 1, nil, + }, + { + "test5", + []testZSetNodeValue{ + {score: 3, member: "mem3", value: ""}, + }, + 10, 20, 0, nil, + }, + { + "test6", + []testZSetNodeValue{ + {score: 3, member: "mem3", value: ""}, + }, + 10, 5, 0, ErrInvalidArgs, + }, + { + "", + []testZSetNodeValue{ + {score: 3, member: "mem3", value: ""}, + }, + 10, 5, 0, _const.ErrKeyIsEmpty, + }, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + for _, value := range tt.input { + _ = zs.ZAdd(tt.key, value.score, value.member, value.value.(string)) + } + got, err := zs.ZCount(tt.key, tt.min, tt.max) + if got != tt.want { + t.Errorf("TestZCount(%v, %v, %v) = %v, want: %v", tt.key, tt.min, tt.max, got, tt.want) + } + if err != nil && err.Error() != tt.err.Error() { + t.Errorf("TestZCount(%v, %v, %v) returned unexpected error: got %v, want: %v", tt.key, tt.min, tt.max, err, tt.err) + } + }) + } +} + +func TestFZSetMinMax(t *testing.T) { + fzs := &FZSet{ + skipList: newSkipList(), + } + _ = fzs.InsertNode(1, "mem1", "") + _ = fzs.InsertNode(100, "mem2", "") + + minScore, maxScore := fzs.getMinMaxScore() + + if minScore != 1 || maxScore != 100 { + t.Errorf("getMinMaxScore() = %d, %d, want: 1, 100", minScore, maxScore) + } + + if min := fzs.min(5, 10); min != 5 { + t.Errorf("min(5, 10) = %d, want: 5", min) + } + + if max := fzs.max(5, 10); max != 10 { + t.Errorf("max(5, 10) = %d, want: 10", max) + } + + // if case part of skip list is missing, we should return 0,0 + fzs.skipList.tail.value = nil + minScore, maxScore = fzs.getMinMaxScore() + assert.Equal(t, minScore, 0) + assert.Equal(t, maxScore, 0) + // again, an error case + fzs.skipList = nil + minScore, maxScore = fzs.getMinMaxScore() + assert.Equal(t, minScore, 0) + assert.Equal(t, maxScore, 0) + +} +func TestZSetStructure_adjustMinMax(t *testing.T) { + zss, _ := NewZSetStructure(config.DefaultOptions) + fz := newZSetNodes() + + _, _, err := zss.adjustMinMax(fz, 100, 0) + assert.Equal(t, ErrInvalidArgs, err) + // + _ = fz.InsertNode(30, "mem1", "") + _ = fz.InsertNode(200, "mem1", "") + minScore, maxScore, err := zss.adjustMinMax(fz, 10, 50) + assert.NoError(t, err) + // as the min now is 30, our provided min of 10 will be turned into 30 + // as our param of max is 50 and maximum score is 200, it won't change + assert.Equal(t, 30, minScore) + assert.Equal(t, 50, maxScore) +} + +func TestZset_getNodeByRank(t *testing.T) { + sl := newSkipList() + sl.insert(1, "mem1", "") + sl.insert(2, "mem2", "") + sl.insert(3, "mem3", "") + tests := []struct { + name string + rank int + want *ZSetValue // Expected Output, use your actual SkipListNode instance or null here + }{ + { + name: "Case 1: Get Node by Rank 1", + rank: 1, + want: &ZSetValue{score: 1, member: "mem1", value: ""}, + }, + { + name: "Case 2: Get Node by Rank 2", + rank: 2, + want: &ZSetValue{score: 2, member: "mem2", value: ""}, + }, + { + name: "Case 3: Get Node by Non-existed Rank", + rank: 9999, + want: nil, // should return nil if rank doesn't exist + }, + } + + // Iterate over test cases + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sl.getNodeByRank(tt.rank) + if tt.want == nil { + assert.Nil(t, got) + + } else { + assert.Equal(t, tt.want, got.value) + } + }) + } +} +func Test_exists(t *testing.T) { + zs, _ := initZSetDB() + + tt := []struct { + key string + score int + member string + want bool + }{ + { + key: "", + score: 1, + member: "", + want: false, + }, + { + key: "key1", + score: 1, + member: "", + want: false, + }, + } + + for _, tc := range tt { + t.Run("", func(t *testing.T) { + got := zs.exists(tc.key, tc.score, tc.member) + + if got != tc.want { + t.Errorf("exists() = %v, want %v", got, tc.want) + } + }) + } +} +func TestNewZSetStructure(t *testing.T) { + tt := []struct { + name string + setup func() (*ZSetStructure, error) + wantErr error + }{ + { + name: "init no error", + setup: func() (*ZSetStructure, error) { + opts := config.DefaultOptions + dir, _ := os.MkdirTemp("", "TestZSetStructure") + opts.DirPath = dir + return NewZSetStructure(opts) + }, + wantErr: nil, + }, + { + name: "init with error wrong path", + setup: func() (*ZSetStructure, error) { + opts := config.DefaultOptions + opts.DirPath = "" + return NewZSetStructure(opts) + }, + wantErr: _const.ErrOptionDirPathIsEmpty, + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.setup() + assert.Equal(t, tc.wantErr, err) + }) + } +}