From fa520ed97f0b3266cbd551d3a67d52a78dd4bf4a Mon Sep 17 00:00:00 2001 From: Saeid Aghapour Date: Thu, 13 Jul 2023 21:13:15 +0330 Subject: [PATCH 1/5] implements stable store component interfaces(#179) --- lib/encoding/messagepack.go | 60 +++++++++++++ structure/zset.go | 169 ++++++++++++++++++++++++++++++++++++ structure/zset_test.go | 61 +++++++++++++ 3 files changed, 290 insertions(+) create mode 100644 structure/zset.go create mode 100644 structure/zset_test.go diff --git a/lib/encoding/messagepack.go b/lib/encoding/messagepack.go index ed8eccf3..11a23321 100644 --- a/lib/encoding/messagepack.go +++ b/lib/encoding/messagepack.go @@ -2,9 +2,46 @@ package encoding import ( "bytes" + "errors" "github.com/hashicorp/go-msgpack/codec" + "reflect" ) +type MessagePackCodec struct { + msgPack *codec.MsgpackHandle +} + +func InitMessagePack() MessagePackCodec { + return MessagePackCodec{ + msgPack: &codec.MsgpackHandle{}, + } +} +func (m MessagePackCodec) Encode(msg interface{}) ([]byte, error) { + m.msgPack.RawToString = true + m.msgPack.WriteExt = true + m.msgPack.MapType = reflect.TypeOf(map[string]interface{}(nil)) + + var b []byte + enc := codec.NewEncoderBytes(&b, m.msgPack) + err := enc.Encode(msg) + if err != nil { + return nil, err + } + return b, nil +} +func (m MessagePackCodec) Decode(in []byte, out interface{}) error { + dev := codec.NewDecoderBytes(in, m.msgPack) // Create a new decoder with the buffer and MessagePack handle + + return dev.Decode(out) // Decode the byte slice into the provided output structure +} +func (m MessagePackCodec) AddExtension( + t reflect.Type, + id byte, + encoder func(reflect.Value) ([]byte, error), + decoder func(reflect.Value, []byte) error) error { + return m.msgPack.AddExt(t, id, encoder, decoder) +} + // EncodeMessagePack is a function that encodes a given message using MessagePack serialization. // It takes an interface{} parameter representing the message and returns the encoded byte slice and an error. func EncodeMessagePack(msg interface{}) ([]byte, error) { @@ -30,3 +67,26 @@ func DecodeMessagePack(in []byte, out interface{}) error { return dev.Decode(out) // Decode the byte slice into the provided output structure } + +func EncodeString(s string) ([]byte, error) { + if len(s) > 0x7F { + return nil, errors.New("invalid string length") + } + b := make([]byte, len(s)+1) + b[0] = byte(len(s)) + copy(b[1:], s) + return b, nil +} + +func DecodeString(b []byte) (int, string, error) { + if len(b) == 0 { + return 0, "", errors.New("invalid length") + } + l := int(b[0]) + if len(b) < (l + 1) { + return 0, "", errors.New("invalid length") + } + s := make([]byte, l) + copy(s, b[1:l+1]) + return l + 1, string(s), nil +} diff --git a/structure/zset.go b/structure/zset.go new file mode 100644 index 00000000..13cb0e2d --- /dev/null +++ b/structure/zset.go @@ -0,0 +1,169 @@ +package structure + +import ( + "bytes" + "container/heap" + "encoding/binary" + "errors" + "github.com/ByteStorage/FlyDB/config" + "github.com/ByteStorage/FlyDB/engine" + _const "github.com/ByteStorage/FlyDB/lib/const" + "github.com/ByteStorage/FlyDB/lib/encoding" + "reflect" +) + +// ZSetStructure is a structure for ZSet or SortedSet +type ZSetStructure struct { + db *engine.DB +} +type ZSetNodes []*ZSetNode // implements heap.Interface and holds ZSetNode. + +type ZSetNode struct { + Value string // The value of the item; arbitrary. + Priority int // The priority of the item in the queue. + Index int // The index of the item in the heap. +} + +func NewZSetStructure(options config.Options) (*ZSetStructure, error) { + db, err := engine.NewDB(options) + if err != nil { + return nil, err + } + + return &ZSetStructure{db: db}, nil +} +func (zs *ZSetStructure) ZAdd(key string, score int, value string) error { + if len(key) == 0 { + return _const.ErrKeyIsEmpty + } + keyBytes := stringToBytesWithKey(key) + _, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return err + } + return nil +} + +func (pq ZSetNodes) Len() int { return len(pq) } + +func (pq ZSetNodes) Less(i, j int) bool { + // We want Pop to give us the highest, not lowest, priority so we use greater than here. + return pq[i].Priority > pq[j].Priority +} + +func (pq ZSetNodes) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] + pq[i].Index = i + pq[j].Index = j +} + +func (pq *ZSetNodes) Push(x any) { + n := len(*pq) + item := x.(*ZSetNode) + item.Index = n + *pq = append(*pq, item) + heap.Fix(pq, n) +} + +func (pq *ZSetNodes) Pop() any { + old := *pq + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.Index = -1 // for safety + *pq = old[0 : n-1] + return item +} + +// update modifies the priority and value of an Item in the queue. +func (pq *ZSetNodes) update(item *ZSetNode, value string, priority int) { + item.Value = value + item.Priority = priority + heap.Fix(pq, item.Index) +} +func (pq *ZSetNodes) Bytes() ([]byte, error) { + msgPack := encoding.InitMessagePack() + + err := msgPack.AddExtension(reflect.TypeOf(ZSetNode{}), 1, zSetNodesEncoder, zSetNodesDecoder) + if err != nil { + return nil, err + } + return msgPack.Encode(pq) +} +func (pq *ZSetNodes) FromBytes(bytes []byte) error { + msgPack := encoding.InitMessagePack() + err := msgPack.AddExtension(reflect.TypeOf(ZSetNode{}), 1, nil, zSetNodesDecoder) + if err != nil { + return err + } + return msgPack.Decode(bytes, pq) +} + +func (l *ZSetStructure) getZSetFromDB(key []byte) (*ZSetNodes, error) { + // Get data corresponding to the key from the database + dbData, err := l.db.Get(key) + + // Since the key might not exist, we need to handle ErrKeyNotFound separately as it is a valid case + if err != nil && err != _const.ErrKeyNotFound { + return nil, err + } + var zSetValue ZSetNodes + // Deserialize the data into a list + err = encoding.DecodeMessagePack(dbData, zSetValue) + if err != nil { + return nil, err + } + return &zSetValue, nil +} +func (l *ZSetStructure) setZSetToDB(key []byte, zSetValue ZSetNodes) error { + // Deserialize the data into a list + val, err := encoding.EncodeMessagePack(zSetValue) + if err != nil { + return err + } + err = l.db.Put(key, val) + if err != nil { + return err + } + return nil +} + +func zSetNodesDecoder(value reflect.Value, i []byte) error { + bs := ZSetNode{} + var bytesRead int + num, s, err := encoding.DecodeString(i) + if err != nil { + return err + } + bytesRead += num + bs.Value = s + val, num := binary.Varint(i[bytesRead:]) + bytesRead += num + bs.Index = int(val) + val, num = binary.Varint(i[bytesRead:]) + bytesRead += num + bs.Priority = int(val) + value.Set(reflect.ValueOf(bs)) + return nil +} +func zSetNodesEncoder(value reflect.Value) ([]byte, error) { + zsn := value.Interface().(ZSetNode) + if zsn.Value == "" { + return nil, errors.New("empty zset") + } + buf := bytes.NewBuffer(nil) + es, err := encoding.EncodeString(zsn.Value) + if err != nil { + return nil, err + } + _, err = buf.Write(es) + if err != nil { + return nil, err + } + b := make([]byte, binary.MaxVarintLen64) + written := 0 + written += binary.PutVarint(b[:], int64(zsn.Index)) + written += binary.PutVarint(b[written:], int64(zsn.Priority)) + buf.Write(b[:written]) + return buf.Bytes(), nil +} diff --git a/structure/zset_test.go b/structure/zset_test.go new file mode 100644 index 00000000..e8f0f025 --- /dev/null +++ b/structure/zset_test.go @@ -0,0 +1,61 @@ +package structure + +import ( + "container/heap" + "github.com/stretchr/testify/assert" + "reflect" + "testing" +) + +func TestSortedSet(t *testing.T) { + items := map[string]int{ + "banana": 3, "apple": 2, "pear": 4, + } + // Create a priority queue, put the items in it, and + // establish the priority queue (heap) invariants. + pq := ZSetNodes{} + pq = make([]*ZSetNode, len(items)) + i := 0 + for value, priority := range items { + pq[i] = &ZSetNode{ + Value: value, + Priority: priority, + Index: i, + } + i++ + } + + heap.Init(&pq) + pq.Push(&ZSetNode{"Pineapple", 50, 0}) + //heap.Fix(&pq, len(pq)-1) + //pq.update(pq[0], pq[0].value, 0) + + t.Log(pq) +} + +func TestSortedSet_Bytes(t *testing.T) { + items := map[string]int{ + "banana": 3, "apple": 2, "pear": 4, + } + // Create a priority queue, put the items in it, and + // establish the priority queue (heap) invariants. + pq := ZSetNodes{} + pq = make([]*ZSetNode, len(items)) + i := 0 + for value, priority := range items { + pq[i] = &ZSetNode{ + Value: value, + Priority: priority, + Index: i, + } + i++ + } + + heap.Init(&pq) + b, err := pq.Bytes() + assert.NoError(t, err) + rb := ZSetNodes{} + err = rb.FromBytes(b) + assert.NoError(t, err) + assert.True(t, reflect.DeepEqual(rb, pq)) +} From 6ece20582db2097780a0fb3a4c96f8d6bbc1e49e Mon Sep 17 00:00:00 2001 From: Saeid Aghapour Date: Mon, 17 Jul 2023 21:17:54 +0330 Subject: [PATCH 2/5] implements stable store component interfaces(#179) --- lib/encoding/messagepack.go | 152 ++++- lib/encoding/messagepack_test.go | 81 +++ structure/zset.go | 994 ++++++++++++++++++++++++++++--- structure/zset_test.go | 225 +++++-- 4 files changed, 1285 insertions(+), 167 deletions(-) diff --git a/lib/encoding/messagepack.go b/lib/encoding/messagepack.go index 11a23321..075ba371 100644 --- a/lib/encoding/messagepack.go +++ b/lib/encoding/messagepack.go @@ -7,86 +7,178 @@ import ( "reflect" ) +// MessagePackCodec struct, holds references to MessagePack handler and byte slice, +// along with Encoder and Decoder, and a typeMap for storing reflect.Type type MessagePackCodec struct { - msgPack *codec.MsgpackHandle + MsgPack *codec.MsgpackHandle + b *[]byte + enc *codec.Encoder + dec *codec.Decoder } +// MessagePackCodecEncoder struct derives from MessagePackCodec +// it manages IDs and counts of the encoded objects. +type MessagePackCodecEncoder struct { + MessagePackCodec // Embedded MessagePackCodec + + // nextId is used probably for tracking ID of the next object to encode. + nextId uint + + // objects represents the count of objects that have been encoded. + objects int +} + +// MessagePackCodecDecoder struct, holds a reference to a MessagePackCodec instance. +type MessagePackCodecDecoder struct { + MessagePackCodec +} + +// InitMessagePack function initializes MessagePackCodec struct and returns it. func InitMessagePack() MessagePackCodec { return MessagePackCodec{ - msgPack: &codec.MsgpackHandle{}, + MsgPack: &codec.MsgpackHandle{}, + } +} + +// NewMessagePackEncoder function creates new MessagePackCodecEncoder and initializes it. +func NewMessagePackEncoder() *MessagePackCodecEncoder { + msgPack := &codec.MsgpackHandle{} + b := make([]byte, 0) + return &MessagePackCodecEncoder{ + MessagePackCodec: MessagePackCodec{ + MsgPack: &codec.MsgpackHandle{}, + b: &b, + enc: codec.NewEncoderBytes(&b, msgPack), + }, } } -func (m MessagePackCodec) Encode(msg interface{}) ([]byte, error) { - m.msgPack.RawToString = true - m.msgPack.WriteExt = true - m.msgPack.MapType = reflect.TypeOf(map[string]interface{}(nil)) +// NewMessagePackDecoder function takes in a byte slice, and returns a pointer to newly created +// and initialized MessagePackCodecDecoder +func NewMessagePackDecoder(b []byte) *MessagePackCodecDecoder { + msgPack := &codec.MsgpackHandle{} + return &MessagePackCodecDecoder{ + MessagePackCodec: MessagePackCodec{ + MsgPack: &codec.MsgpackHandle{}, + b: &b, + dec: codec.NewDecoderBytes(b, msgPack), + }, + } +} + +// Encode method for MessagePackCodec. It encodes the input value into a byte slice using MessagePack. +// Returns encoded byte slice or error. +func (m *MessagePackCodec) Encode(msg interface{}) ([]byte, error) { var b []byte - enc := codec.NewEncoderBytes(&b, m.msgPack) - err := enc.Encode(msg) + err := codec.NewEncoderBytes(&b, m.MsgPack).Encode(msg) if err != nil { return nil, err } return b, nil } -func (m MessagePackCodec) Decode(in []byte, out interface{}) error { - dev := codec.NewDecoderBytes(in, m.msgPack) // Create a new decoder with the buffer and MessagePack handle - return dev.Decode(out) // Decode the byte slice into the provided output structure +// Encode is a method for MessagePackCodecEncoder. +// It takes in msg of type interface{} as input, that is to be encoded. +// Returns an error if encountered during encoding. +func (m *MessagePackCodecEncoder) Encode(msg interface{}) error { + return m.enc.Encode(msg) +} + +// Bytes is a method for MessagePackCodecEncoder. +// It returns a byte slice pointer b. +func (m *MessagePackCodecEncoder) Bytes() []byte { + return *m.b +} + +// Decode is a method on MessagePackCodecDecoder that decodes MessagePack data +// into the provided interface; returns an error if any decoding issues occur. +func (m *MessagePackCodecDecoder) Decode(msg interface{}) error { + if m.dec == nil { + return errors.New("decoder not initialized") + } + return m.dec.Decode(msg) } -func (m MessagePackCodec) AddExtension( + +// Decode on MessagePackCodec type, using a byte slice as input. +func (m *MessagePackCodec) Decode(in []byte, out interface{}) error { + // Create new decoder using the byte slice and MessagePack handle. + dec := codec.NewDecoderBytes(in, m.MsgPack) + + // Attempt to decode the byte slice into the desired output structure. + return dec.Decode(out) +} + +// AddExtension method allows for setting custom encoders/decoders for specific reflect.Types. +func (m *MessagePackCodec) AddExtension( t reflect.Type, id byte, encoder func(reflect.Value) ([]byte, error), decoder func(reflect.Value, []byte) error) error { - return m.msgPack.AddExt(t, id, encoder, decoder) + + return m.MsgPack.AddExt(t, id, encoder, decoder) } -// EncodeMessagePack is a function that encodes a given message using MessagePack serialization. -// It takes an interface{} parameter representing the message and returns the encoded byte slice and an error. +// EncodeMessagePack function encodes a given object into MessagePack format. func EncodeMessagePack(msg interface{}) ([]byte, error) { - var b []byte - var mph codec.MsgpackHandle - h := &mph - enc := codec.NewEncoderBytes(&b, h) // Create a new encoder with the provided message and MessagePack handle + // Directly initialize the byte slice and encoder. + b := make([]byte, 0) + enc := codec.NewEncoderBytes(&b, &codec.MsgpackHandle{}) - err := enc.Encode(msg) // Encode the message using the encoder - if err != nil { + // Attempt to encode the message. + if err := enc.Encode(msg); err != nil { return nil, err } - return b, nil // Return the encoded byte slice + + // Return the encoded byte slice. + return b, nil } -// DecodeMessagePack is a function that decodes a given byte slice using MessagePack deserialization. -// It takes an input byte slice and an interface{} representing the output structure for the deserialized message. -// It returns an error if the decoding process fails. +// DecodeMessagePack function decodes a byte slice of MessagePack data into a given object. func DecodeMessagePack(in []byte, out interface{}) error { - buf := bytes.NewBuffer(in) // Create a new buffer with the input byte slice - mph := codec.MsgpackHandle{} - dev := codec.NewDecoder(buf, &mph) // Create a new decoder with the buffer and MessagePack handle - - return dev.Decode(out) // Decode the byte slice into the provided output structure + dec := codec.NewDecoder(bytes.NewBuffer(in), &codec.MsgpackHandle{}) + return dec.Decode(out) } +// EncodeString Functions for encoding and decoding strings to and from byte slices. func EncodeString(s string) ([]byte, error) { + // Check if string length is within correct bounds. if len(s) > 0x7F { return nil, errors.New("invalid string length") } + + // Create a byte slice of appropriate length. b := make([]byte, len(s)+1) b[0] = byte(len(s)) + + // Copy the string into the byte slice. copy(b[1:], s) + + // Return the byte slice. return b, nil } +// DecodeString is a function that takes an input byte slice and attempts to decode it to obtain a string. +// Return parameters are an integer, a string and an error. Integer denotes the length of the byte slice +// representation of the string including length-field. The second return parameter is the decoded string. +// DecodeString raises an error if the length of byte slice is less than the expected string length plus +// one (considering the string length field) or if the provided byte slice is empty. +// If successful, returns length of byte representation of string, the decoded string and a nil error. func DecodeString(b []byte) (int, string, error) { + // Check that byte slice is not empty. if len(b) == 0 { return 0, "", errors.New("invalid length") } + + // Determine the length of the string. l := int(b[0]) if len(b) < (l + 1) { return 0, "", errors.New("invalid length") } + + // Create a byte slice of the appropriate length and copy the string into it. s := make([]byte, l) copy(s, b[1:l+1]) + + // Return the length of the string and the string itself. return l + 1, string(s), nil } diff --git a/lib/encoding/messagepack_test.go b/lib/encoding/messagepack_test.go index 68b12adf..1fb39808 100644 --- a/lib/encoding/messagepack_test.go +++ b/lib/encoding/messagepack_test.go @@ -1,6 +1,8 @@ package encoding import ( + "bytes" + "github.com/hashicorp/go-msgpack/codec" "github.com/hashicorp/raft" "github.com/stretchr/testify/assert" "testing" @@ -16,3 +18,82 @@ func TestEncodeMessagePack(t *testing.T) { assert.NoError(t, err) assert.Equal(t, decData, *data) } +func TestInitMessagePack(t *testing.T) { + msgPack := InitMessagePack() + + assert.NotNil(t, msgPack.MsgPack) +} + +func TestNewMessagePackEncoder(t *testing.T) { + encoder := NewMessagePackEncoder() + + assert.NotNil(t, encoder.enc) + assert.NotNil(t, encoder.b) +} + +func TestNewMessagePackDecoder(t *testing.T) { + exampleBytes := []byte("example") + decoder := NewMessagePackDecoder(exampleBytes) + + assert.NotNil(t, decoder.b) + assert.NotNil(t, decoder.dec) +} + +func TestEncode(t *testing.T) { + msg := "example message" + msgPack := InitMessagePack() + + encoded, err := msgPack.Encode(msg) + + assert.NotNil(t, encoded) + assert.Nil(t, err) +} + +func TestMsgPackDecoder_Decode(t *testing.T) { + msg := "example message" + msgPack := InitMessagePack() + encoded, _ := msgPack.Encode(msg) + decoder := NewMessagePackDecoder(encoded) + + var decoded string + err := decoder.Decode(&decoded) + + assert.Nil(t, err) + assert.Equal(t, msg, decoded) +} + +func TestMsgPackDecoder_Decode_ErrDecoderNotInitialized(t *testing.T) { + msgPack := InitMessagePack() + encoded, _ := msgPack.Encode("example message") + decoder := &MessagePackCodecDecoder{ + MessagePackCodec: MessagePackCodec{ + MsgPack: &codec.MsgpackHandle{}, + b: &encoded, + }, + } + + var decoded string + err := decoder.Decode(&decoded) + + assert.NotNil(t, err) + assert.Equal(t, "decoder not initialized", err.Error()) +} + +func TestEncodeString(t *testing.T) { + s := "example string" + + encoded, err := EncodeString(s) + + assert.NotNil(t, encoded) + assert.Nil(t, err) +} + +func TestEncodeString_ErrStringLength(t *testing.T) { + s := bytes.Repeat([]byte("a"), 0x80) // 128-byte long string + + encoded, err := EncodeString(string(s)) + + assert.Nil(t, encoded) + assert.NotNil(t, err) + assert.Equal(t, "invalid string length", err.Error()) +} diff --git a/structure/zset.go b/structure/zset.go index 13cb0e2d..236eae33 100644 --- a/structure/zset.go +++ b/structure/zset.go @@ -1,169 +1,975 @@ package structure import ( - "bytes" - "container/heap" - "encoding/binary" "errors" + "fmt" "github.com/ByteStorage/FlyDB/config" + "math" + "math/rand" + "github.com/ByteStorage/FlyDB/engine" _const "github.com/ByteStorage/FlyDB/lib/const" "github.com/ByteStorage/FlyDB/lib/encoding" - "reflect" ) +const ( + // SKIPLIST_MAX_LEVEL is better to be log(n) for the best performance. + SKIPLIST_MAX_LEVEL = 10 // + SKIPLIST_PROB = 0.25 // SkipList Probability +) + +/** +ZSet or Sorted Set structure is borrowed from Redis' implementation, the Redis implementation +utilizes a SkipList and a dictionary. +*/ + // ZSetStructure is a structure for ZSet or SortedSet type ZSetStructure struct { db *engine.DB } -type ZSetNodes []*ZSetNode // implements heap.Interface and holds ZSetNode. -type ZSetNode struct { - Value string // The value of the item; arbitrary. - Priority int // The priority of the item in the queue. - Index int // The index of the item in the heap. +// ZSetNodes represents a specific data structure in the database, which is key to handling sorted sets (ZSets). +// This struct facilitates interactions with data stored in the sorted set, allowing for both complex and simple operations. +// +// It contains three struct fields: +// +// - 'dict': A Go map with string keys and pointers to SkipListNodeValue values. This map aims to provide quick access to +// individual values in the sorted set based on the provided key. +// +// - 'size': An integer value representing the current size (number of elements) in the ZSetNodes struct. This information is efficiently +// kept track of whenever elements are added or removed from the set, so no separate computation is needed to retrieve this information. +// +// - 'skipList': A pointer towards a SkipList struct. SkipLists perform well under numerous operations, such as insertion, deletion, and searching. They are +// a crucial component in maintaining the sorted set in a practical manner. In this context, the SkipList is used to keep an ordered track of the elements +// in the ZSetNodes struct. +type ZSetNodes struct { + // dict field is a map where the key is a string and + // the value is a pointer to SkipListNodeValue instances, + // codified with the tag "dict". + dict map[string]*SkipListNodeValue `codec:"dict"` + + // size field represents the quantity of elements within + // the structure, codified with the tag "size". + size int `codec:"size"` + + // skipList field is a pointer to an object of type SkipList, + // codified with the tag "skip_list". + skipList *SkipList `codec:"skip_list"` +} + +// SkipList represents a skip list data structure, an ordered list with a hierarchical +// structure that allows for fast search and insertion of elements. +type SkipList struct { + // level represents the highest level of the skip list. + level int + + // head refers to the first node in the skip list. + head *SkipListNode + + // tail refers to the last node in the skip list. + tail *SkipListNode + + // size represents the total number of nodes in the skip list (excluding head and tail nodes). + size int +} + +// SkipListLevel is a structure encapsulating a single level in a skip list data structure. +// It contains two struct fields: +// - 'next': A pointer to the next SkipListNode in the current level. +// - 'span': An integer representing the span size of this SkipListLevel. The span is the number of nodes between the current node +// and the node to which the next pointer is pointing in the skip list. +type SkipListLevel struct { + next *SkipListNode + span int } +// SkipListNode represents a single node in a SkipList structure. +// It is built with three elements: +// - 'prev': This is a pointer to the previous node in the skip list. Together with the 'next' pointers in the SkipListNodeLevel, +// it forms a network of nodes, where traversal of the skip list is possible both forwards and backwards. +// - 'level': This is an array (slice) of pointers towards SkipListLevel structures. Each element corresponds to a level of the skip list, +// embedding the 'next' node at that same level, and the span between the current node and that 'next' node. +// - 'value': This is a pointer towards a single SkipListNodeValue structure. It holds the actual payload of the node +// (namely the 'score', 'key', and 'value' properties used in the context of Redis Sorted Sets), as well as provides the basis for ordering of nodes in the skip list. +type SkipListNode struct { + // prev is a pointer to the previous node in the skip list. + prev *SkipListNode + + // level is a slice of pointers to SkipListLevel. + // Each level represents a forward pointer to the next node in the current list level. + level []*SkipListLevel + + // value is a pointer to the SkipListNodeValue. + // This represents the value that this node holds. + value *SkipListNodeValue +} + +// SkipListNodeValue is a struct used in the SkipList data structure. In the context of Redis Sorted Set (ZSet) implementation, +// it represents a single node value in the skip list. A SkipListNodeValue has three members: +// - 'score' which is an integer representing the score of the node. Nodes in a skip list are ordered by this score in ascending order. +// - 'member' which is a string defining the key of the node. For nodes with equal scores, order is determined with lexicographical comparison of keys. +// - 'value' which is an interface{}, meaning it can hold any data type. This represents the actual value of the node in the skip list. +type SkipListNodeValue struct { + // Score is typically used for sorting purposes. Nodes with higher scores will be placed higher in the skip list. + score int + + // member represents the unique identifier for each node. + member string + + // value is the actual content/data that is being stored in the node. + value interface{} +} + +// randomLevel is a function that generates a probabilistic level for a node in a SkipList data structure. +// The goal is to diversify the level distribution and contribute to achieving an ideal skiplist performance. +// Function has no parameters. +// The process starts with two initial variables: +// - 'level' which starts from 1, +// - 'thresh' which is a product of the constant skiplist probability 'SKIPLIST_PROB' and bitwise mask: 0xFFF, taken to the nearest integer. +// +// In an infinite loop, a random 31-bit integer value is generated, bitwise-and is computed with 0xFFF and compared with 'thresh'. +// If the result is smaller, 'level' is incremented by one. Otherwise, the loop is exited. +// Finally, the function checks the calculated level against the maximum allowed skiplist level 'SKIPLIST_MAX_LEVEL'. +// If 'level' is greater, 'SKIPLIST_MAX_LEVEL' is returned, otherwise the calculated 'level' value is returned. +// The function returns an integer which will be the level of new node in skiplist. +func randomLevel() int { + // Initialize level to 1 + level := 1 + + // Calculate the threshold for level. It's derived from the probability constant of the skip list. + thresh := int(math.Round(SKIPLIST_PROB * 0xFFF)) + + // While a randomly generated number is less than this threshold, increment the level. + for int(rand.Int31()&0xFFF) < thresh { + level++ + } + + // Check if the level is more than the maximum allowed level for the skip list + // If it is, return the maximum level. Otherwise, return the generated level. + if level > SKIPLIST_MAX_LEVEL { + return SKIPLIST_MAX_LEVEL + } else { + return level + } +} + +// NewZSetStructure Returns a new ZSetStructure func NewZSetStructure(options config.Options) (*ZSetStructure, error) { db, err := engine.NewDB(options) if err != nil { return nil, err } - return &ZSetStructure{db: db}, nil } -func (zs *ZSetStructure) ZAdd(key string, score int, value string) error { + +// newZSetNodes is a function that creates a new ZSetNodes object and returns a pointer to it. +// It initializes the dictionary member dict of the newly created object to an empty map. +// The map is intended to map strings to pointers of SkipListNodeValue objects. +// size member of the object is set to 0, indicating that the ZSetNodes object is currently empty. +// The skipList member of the object is set to a new SkipList object created by calling `newSkipList()` function. +func newZSetNodes() *ZSetNodes { + return &ZSetNodes{ + dict: make(map[string]*SkipListNodeValue), + size: 0, + skipList: newSkipList(), + } +} + +// newSkipList is a function that creates an instance of a SkipList struct object and returns a pointer to it. +// This involves initializing the level of the SkipList to 1 and creating a new SkipListNode object as the head of the list. +// The head node is constructed with a level set to SKIPLIST_MAX_LEVEL, key and value as empty string and value as nil respectively. +func newSkipList() *SkipList { + return &SkipList{ + level: 1, + head: newSkipListNode(SKIPLIST_MAX_LEVEL, 0, "", nil), + } +} + +// newSkipListNode is a function that takes integer as level, score and a string as key along with a value of any type. +// It returns a pointer to a SkipListNode. This function is responsible for creating a new SkipListNode with provided level, score, +// key, and value. After creating the node, it initializes every level of the node with an empty SkipListLevel object. +// In the context of a skip list data structure, this function serves as a helper function for creating new nodes to be inserted to the list. +func newSkipListNode(level int, score int, key string, value interface{}) *SkipListNode { + // Create a new SkipListNode with specified score, key, value and a slice of + // SkipListLevel with length equal to specified level + node := &SkipListNode{ + value: newSkipListNodeValue(score, key, value), + level: make([]*SkipListLevel, level), + } + + // Initialize each SkipListLevel in the level slice + for i := range node.level { + node.level[i] = new(SkipListLevel) + } + // Returning the pointer to the created node + return node +} + +// newSkipListNodeValue is a function that constructs and returns a new SkipListNodeValue. +// It takes a score (int), a key (string), and a value (interface{}) as parameters. +// These parameters serve as the initial state of the SkipListNodeValue upon its creation. +func newSkipListNodeValue(score int, member string, value interface{}) *SkipListNodeValue { + // Create a new instance of a SkipListNodeValue with the provided score, key, and value. + node := &SkipListNodeValue{ + score: score, + member: member, + value: value, + } + + // Return the newly created SkipListNodeValue. + return node +} + +// insert is a method of the SkipList type that is used to insert a new node into the skip list. It takes as arguments +// the score (int), key (string) and a value (interface{}), and returns a pointer to the SkipListNodeValue struct. The method +// organizes nodes in the list based on the score in ascending order. If two nodes have the same score, they will be arranged +// based on the key value. The method also assigns span values to the levels in the skip list. +func (sl *SkipList) insert(score int, key string, value interface{}) *SkipListNodeValue { + update := make([]*SkipListNode, SKIPLIST_MAX_LEVEL) + rank := make([]int, SKIPLIST_MAX_LEVEL) + node := sl.head + + // Go from highest level to lowest + for i := sl.level - 1; i >= 0; i-- { + // store rank that is crossed to reach the insert position + if sl.level-1 == i { + rank[i] = 0 + } else { + rank[i] = rank[i+1] + } + if node.level[i] != nil { + for node.level[i].next != nil && + (node.level[i].next.value.score < score || + (node.level[i].next.value.score == score && // score is the same but the key is different + node.level[i].next.value.member < key)) { + rank[i] += node.level[i].span + node = node.level[i].next + } + } + update[i] = node + } + level := randomLevel() + // add a new level + if level > sl.level { + for i := sl.level; i < level; i++ { + rank[i] = 0 + update[i] = sl.head + update[i].level[i].span = sl.size + } + sl.level = level + } + node = newSkipListNode(level, score, key, value) + + for i := 0; i < level; i++ { + node.level[i].next = update[i].level[i].next + update[i].level[i].next = node + // update span covered by update[i] as newNode is inserted here + node.level[i].span = update[i].level[i].span - (rank[0] - rank[i]) + update[i].level[i].span = (rank[0] - rank[i]) + 1 + } + // increment span for untouched levels + for i := level; i < sl.level; i++ { + update[i].level[i].span++ + } + // update info + if update[0] == sl.head { + node.prev = nil + } else { + node.prev = update[0] + } + if node.level[0].next != nil { + node.level[0].next.prev = node + } else { + sl.tail = node + } + sl.size++ + return node.value +} + +// SkipList is a data structure that allows fast search, insertion, and removal operations. +// Here we define a method delete on it. +// +// The delete method in the skip list will remove nodes that have a given score and key from the skip list. +// If no such nodes are found, the function does nothing. +// +// Parameters: +// +// score: the score of the node to delete. +// key: the key of the node to delete. +func (sl *SkipList) delete(score int, member string) { + + // update: an array of pointers to SkipListNodes; holds the nodes that will have their next pointers updated. + update := make([]*SkipListNode, SKIPLIST_MAX_LEVEL) + + // node: start from the head of our SkipList sl + node := sl.head + + // The code block of "for" loop populates the "update" variable with nodes which reference will change + // due to the removal of the target node. + for i := sl.level; i >= 0; i-- { + // This loop is traversing the SkipList horizontally until it finds a node with a score greater + // than or equal to our target score or if the scores are equal it also checks the member. + for node.level[i].next != nil && + (node.level[i].next.value.score < score || + (node.level[i].next.value.score == score && + node.level[i].next.value.member < member)) { + node = node.level[i].next + } + update[i] = node + } + + // After the traversal, we set the node to point to the possibly (to be) deleted node. + node = node.level[0].next + + // If the possibly deleted node is the target node (it has the same score and member), then remove it. + if node != nil && node.value.score == score && node.value.member == member { + sl.deleteNode(node, update) + } +} +func (sl *SkipList) getRange(start int, end int, reverse bool) (nv []SkipListNodeValue) { + if end > sl.size { + end = sl.size - 1 + } + if start > end { + return + } + if end < 0 { + return nil // todo unexpected behavior, we can set it to zero as well + } + node := sl.head + if reverse { + node = sl.getNodeByRank(end) + } else { + node = sl.getNodeByRank(start) + } + if reverse { + node = sl.getNodeByRank(end) + } else { + node = sl.getNodeByRank(start) + } + for i := start; i < end; i++ { + if reverse { + nv = append(nv, *node.value) + node = node.prev + } else { + nv = append(nv, *node.value) + node = node.level[0].next + } + } + return nv +} + +// deleteNode is a method linked to the SkipList struct that allows to remove nodes from the SkipList instance. +// It takes two parameters: a pointer to the node to be deleted, and a slice of pointers to SkipListNode which are required for node updates. +// deleteNode performs the deletion through a two-step process: +// - First, it loops over every level in the SkipList, updating level spans and next node pointers accordingly. +// - Then, it sets the pointers back to the previous node in the data structure and updates the tail and level of the whole list. +// Finally, it decreases the size of the list by one, as a node is being removed from it. +// It doesn't return any value and modifies the SkipList directly. + +func (sl *SkipList) deleteNode(node *SkipListNode, updates []*SkipListNode) { + for i := 0; i < sl.level; i++ { + if updates[i].level[i].next == node { + updates[i].level[i].span += node.level[i].span - 1 + updates[i].level[i].next = node.level[i].next + } else { + updates[i].level[i].span-- + } + } + //update backwards + if node.level[0].next != nil { + node.level[0].next.prev = node.prev + } else { + sl.tail = node.prev + } + + for sl.level > 1 && sl.head.level[sl.level-1].next == nil { + sl.level-- + } + sl.size-- +} + +// getRank method receives a SkipList pointer and two parameters: an integer 'score' and a string 'key'. +// It then calculates the rank of an element in the SkipList. The rank is determined based on two conditions: +// - the score of the next node is less than the provided score +// - or, the score of the next node equal to the provided score and the key of the next node is less than or equal to the provided key. +// +// Parameters: +// sl: A pointer to the SkipList object. +// score: The score that we are comparing with the scores in the skiplist. +// key: The key that we are comparing with the keys in the skiplist. +// +// Return: +// Returns the rank of the element in the SkipList if it's found, otherwise returns 0. +func (sl *SkipList) getRank(score int, key string) int { + var rank int + h := sl.head // Start at the head node of the SkipList + + // For loop starts from the top level and goes down to the level 0 + for i := sl.level; i >= 0; i-- { + // While loop advances the 'h' pointer as long as the next node exists and the conditions are fulfilled + for h.level[i].next != nil && + (h.level[i].next.value.score < score || + (h.level[i].next.value.score == score && + h.level[i].next.value.member <= key)) { + + // Increase the rank by the span of the current level + rank += h.level[i].span + // Move to the next node + h = h.level[i].next + } + // If the key of the current node is equal to the provided key, return the rank + if h.value.member == key { + return rank + } + } + // If the element is not found in the SkipList, return 0 + return 0 +} + +// getNodeByRank is a method of the SkipList type that is used to retrieve a node based on its rank within the list. +// The method takes as argument an integer rank and returns a pointer to the SkipListNode at the specified rank, +// or nil if there is no such node. +// +// First, the method initializes a variable traversed to store the cumulative span of nodes traversed thus far in the search. +// It sets a helper variable h to the head of the SkipList, to begin the traversal. +// +// The method then enter a loop that iterates through the levels of the SkipList from the highest down to the base level. +// On each level, while the next node exists and the total span traversed plus the span of the next node doesn't exceed the target rank, +// the method moves to the next node and adds its span to the cumulative span traversed. +// +// If during the traversal the cumulative span equals the target rank, the method returns the getNodeByRank node. +// If the end of the SkipList is reached, or the target rank isn't found on any level, the method returns nil. +func (sl *SkipList) getNodeByRank(rank int) *SkipListNode { + // This variable is used to keep track of the number of nodes we have + // traversed while going through the levels of the SkipList. + var traversed int + + // Define a SkipListNode pointer h, initialized with sl.head + // At the start, this pointer is set to head node of the SkipList. + h := sl.head + + // The outer loop decrements levels from highest level to lowest. + for i := sl.level - 1; i >= 0; i-- { + + // The inner loop traverses the nodes at current level while the next node isn't null and we haven't traversed beyond the 'rank'. + // The traversed variable is also updated to include the span of the current level. + for h.level[i].next != nil && (traversed+h.level[i].span) <= rank { + traversed += h.level[i].span + h = h.level[i].next + } + + // If traversed equals 'rank', it means we've found the node at the rank we are looking for. + // So, return the node. + if traversed == rank { + return h + } + } + + // If the node at 'rank' wasn't found in the SkipList, return nil. + return nil +} + +// ZAdd adds a value with its given score and member to a sorted set (ZSet), associated with +// the provided key. It is a method on the ZSetStructure type. +// +// Parameters: +// +// key: a string that represents the key of the sorted set. +// score: an integer value that determines the order of the added element in the sorted set. +// member: a string used for identifying the added value within the sorted set. +// value: the actual value to be stored within the sorted set. +// +// If the key is an empty string, an error will be returned +func (zs *ZSetStructure) ZAdd(key string, score int, member string, value string) error { + if len(key) == 0 { + return _const.ErrKeyIsEmpty + } + + zSet, err := zs.getOrCreateZSet(key) + + if err != nil { + return fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + + // if values didn't change, do nothing + if zs.valuesDidntChange(zSet, score, member, value) { + return nil + } + + if err := zs.updateZSet(zSet, key, score, member, value); err != nil { + return fmt.Errorf("failed to set ZSet to DB with key '%v': %w", key, err) + } + + return nil +} + +/* +ZRem is a method belonging to ZSetStructure that removes a member from a ZSet. + +Parameters: + - key (string): The key of the ZSet. + - member (string): The member to be removed. + +Returns: + - error: An error if the operation fails. + +The ZRem method checks for a non-empty key, retrieves the corresponding ZSet +from the database, removes the specified member, and then updates +the ZSet in the database. If any point of this operation fails, +the function will return the corresponding error. +*/ +func (zs *ZSetStructure) ZRem(key string, member string) error { if len(key) == 0 { return _const.ErrKeyIsEmpty } keyBytes := stringToBytesWithKey(key) - _, err := zs.getZSetFromDB(keyBytes) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + if err = zSet.RemoveNode(member); err != nil { return err } - return nil + return zs.setZSetToDB(keyBytes, zSet) } -func (pq ZSetNodes) Len() int { return len(pq) } +// ZScore method retrieves the score associated with the member in a sorted set stored at the key +func (zs *ZSetStructure) ZScore(key string, member string) (int, error) { + if len(key) == 0 { + return 0, _const.ErrKeyIsEmpty + } + keyBytes := stringToBytesWithKey(key) -func (pq ZSetNodes) Less(i, j int) bool { - // We want Pop to give us the highest, not lowest, priority so we use greater than here. - return pq[i].Priority > pq[j].Priority + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return 0, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + // if the member in the sorted set is found, return the score associated with it + if v, ok := zSet.dict[member]; ok { + return v.score, nil + } + + // if the member doesn't exist in the set, return score of zero and an error + return 0, _const.ErrKeyNotFound } -func (pq ZSetNodes) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] - pq[i].Index = i - pq[j].Index = j +/* +ZRank is a method belonging to the ZSetStructure type. This method retrieves the rank of an element within a sorted set identified by a key. The rank is an integer corresponding to the element's 0-based position in the sorted set when it is arranged in ascending order. + +Parameters: +key (string): The key that identifies the sorted set. +member (string): The element for which you want to find the rank. + +Returns: +int: An integer indicating the rank of the member in the set. + + Rank zero means the member is not found in the set. + +error: If an error occurs, an error object will be returned. + + Possible errors include: + - key is empty + - failure to get or create the ZSet from the DB + - the provided key does not exist in the DB + +Example: +rank, err := zs.ZRank("myKey", "memberName") + + if err != nil { + log.Fatal(err) + } + +fmt.Printf("The rank of '%s' in the set '%s' is %d\n", "memberName", "myKey", rank) +*/ +func (zs *ZSetStructure) ZRank(key string, member string) (int, error) { + if len(key) == 0 { + return 0, _const.ErrKeyIsEmpty + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return 0, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + if v, ok := zSet.dict[member]; ok { + return zSet.skipList.getRank(v.score, member), nil + } + + // rank zero means no rank found + return 0, _const.ErrKeyNotFound } +func (zs *ZSetStructure) ZRevRank(key string, member string) (int, error) { + if len(key) == 0 { + return 0, _const.ErrKeyIsEmpty + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return 0, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + if v, ok := zSet.dict[member]; ok { + rank := zSet.skipList.getRank(v.score, member) + return zSet.size - rank, nil + } -func (pq *ZSetNodes) Push(x any) { - n := len(*pq) - item := x.(*ZSetNode) - item.Index = n - *pq = append(*pq, item) - heap.Fix(pq, n) + // rank zero means no rank found + return 0, _const.ErrKeyNotFound } +func (zs *ZSetStructure) ZRange(key string, start int, end int) ([]SkipListNodeValue, error) { + if len(key) == 0 { + return nil, _const.ErrKeyIsEmpty + } + keyBytes := stringToBytesWithKey(key) -func (pq *ZSetNodes) Pop() any { - old := *pq - n := len(old) - item := old[n-1] - old[n-1] = nil // avoid memory leak - item.Index = -1 // for safety - *pq = old[0 : n-1] - return item + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return nil, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + r := zSet.skipList.getRange(start, end, false) + + // rank zero means no rank found + return r, nil } +func (zs *ZSetStructure) ZRevRange(key string, start int, end int) ([]SkipListNodeValue, error) { + if len(key) == 0 { + return nil, _const.ErrKeyIsEmpty + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return nil, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + r := zSet.skipList.getRange(start, end, true) -// update modifies the priority and value of an Item in the queue. -func (pq *ZSetNodes) update(item *ZSetNode, value string, priority int) { - item.Value = value - item.Priority = priority - heap.Fix(pq, item.Index) + // rank zero means no rank found + return r, nil } -func (pq *ZSetNodes) Bytes() ([]byte, error) { - msgPack := encoding.InitMessagePack() +func (zs *ZSetStructure) ZCard(key string) (int, error) { + if len(key) == 0 { + return 0, _const.ErrKeyIsEmpty + } + keyBytes := stringToBytesWithKey(key) - err := msgPack.AddExtension(reflect.TypeOf(ZSetNode{}), 1, zSetNodesEncoder, zSetNodesDecoder) + zSet, err := zs.getZSetFromDB(keyBytes) if err != nil { - return nil, err + return 0, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) } - return msgPack.Encode(pq) + // get the size of the dictionary + return zSet.size, nil } -func (pq *ZSetNodes) FromBytes(bytes []byte) error { - msgPack := encoding.InitMessagePack() - err := msgPack.AddExtension(reflect.TypeOf(ZSetNode{}), 1, nil, zSetNodesDecoder) +func (zs *ZSetStructure) ZIncrBy(key string, member string, incBy int) error { + if len(key) == 0 { + return _const.ErrKeyIsEmpty + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) if err != nil { + return fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + if v, ok := zSet.dict[member]; ok { + return zSet.InsertNode(v.score+incBy, member, v.value) + } + + return _const.ErrKeyNotFound +} + +// getOrCreateZSet attempts to retrieve a sorted set by a key, or creates a new one if it doesn't exist. +func (zs *ZSetStructure) getOrCreateZSet(key string) (*ZSetNodes, error) { + keyBytes := stringToBytesWithKey(key) + zSet, err := zs.getZSetFromDB(keyBytes) + // if key is not in the DB, create it. + if errors.Is(err, _const.ErrKeyNotFound) { + return newZSetNodes(), nil + } + + return zSet, err +} + +// valuesDidntChange checks if the data of a specific member in a sorted set remained the same. +func (zs *ZSetStructure) valuesDidntChange(zSet *ZSetNodes, score int, member string, value string) bool { + if v, ok := zSet.dict[member]; ok { + return v.score == score && v.member == member && v.value == value + } + + return false +} + +// updateZSet updates or inserts a member in a sorted set and saves the change in storage. +func (zs *ZSetStructure) updateZSet(zSet *ZSetNodes, key string, score int, member string, value string) error { + if err := zSet.InsertNode(score, member, value); err != nil { return err } - return msgPack.Decode(bytes, pq) + + return zs.setZSetToDB(stringToBytesWithKey(key), zSet) } -func (l *ZSetStructure) getZSetFromDB(key []byte) (*ZSetNodes, error) { - // Get data corresponding to the key from the database - dbData, err := l.db.Get(key) +// InsertNode is a method on the ZSetNodes structure. It inserts a new node +// or updates an existing node in the skip list and the dictionary. +// It takes three parameters: score (an integer), key (a string), +// and value (of any interface type). +// +// If key already exists in the dictionary and the score equals the existing +// score, it updates the value and score in the skip list and the dictionary. +// If the score is different, it only updates the value in the dictionary +// because the ranking doesn't change and there is no need for an update in the +// skip list. +// +// If the key doesn't exist in the dictionary, it adds the new key, value and score +// to the dictionary, increments the size of the dictionary by 1, and also adds +// the node to the skip list. +func (pq *ZSetNodes) InsertNode(score int, member string, value interface{}) error { + // Instantiate dictionary if it's not already + if pq.dict == nil { + pq.dict = make(map[string]*SkipListNodeValue) + } + + // Check if key exists in dictionary + if v, ok := pq.dict[member]; ok { + if v.score == score { + // Update value and score as the score remains the same + pq.skipList.delete(score, member) + pq.dict[member] = pq.skipList.insert(score, member, value) + } else { + // Ranking isn't altered, only update value + v.value = value + } + } else { // Key doesn't exist, create new key + pq.dict[member] = pq.skipList.insert(score, member, value) + pq.size++ // Increase size count by 1 + // Node is also added to the skip list + } + + // Returns nil as no specific error condition is checked in this function + return nil +} + +// RemoveNode is a method for ZSetNodes structure. +// This method aims to delete a node from both +// the dictionary (dict) and the skip list (skipList). +// +// The method receives one parameter: +// - member: a string that represents the key of the node +// to be removed from the ZSetNodes structure. +// +// The method follows these steps: +// 1. Check if a node with key 'member' exists in the dictionary. +// If not, or if the dictionary itself is nil, it returns an error +// (_const.ErrKeyNotFound) indicating that the node cannot be found. +// 2. If the node exists, it proceeds to remove the node from both the +// skip list and dictionary. +// 3. After the successful removal of the node, it returns nil indicating +// the success of the operation. +// +// The RemoveNode's primary purpose is to provide a way to securely and +// efficiently remove a node from the ZSetNodes structure. +func (pq *ZSetNodes) RemoveNode(member string) error { + // Check for existence of key in dictionary + v, ok := pq.dict[member] + if !ok || pq.dict == nil { + return _const.ErrKeyNotFound + } + + // Delete Node from the skip list and dictionary + pq.skipList.delete(v.score, member) + delete(pq.dict, member) + pq.size-- + + return nil +} + +func (pq *ZSetNodes) exists(score int, member string) bool { + v, ok := pq.dict[member] + + return ok && v.score == score +} + +// Bytes encodes the ZSetNodes instance into bytes using MessagePack +// binary serialization format. The encoded bytes can be used for +// storage or transmission. If the encoding operation fails, an +// error is returned. +func (pq *ZSetNodes) Bytes() ([]byte, error) { + var msgPack = encoding.NewMessagePackEncoder() + if encodingError := msgPack.Encode(pq); encodingError != nil { + return nil, encodingError + } + return msgPack.Bytes(), nil +} + +// FromBytes decodes the input byte slice into the ZSetNodes object using MessagePack. +// Returns an error if decoding fails, otherwise nil. +func (pq *ZSetNodes) FromBytes(b []byte) error { + return encoding.NewMessagePackDecoder(b).Decode(pq) +} + +// getZSetFromDB fetches and deserializes ZSetNodes from the database. +// +// Returns a pointer to the ZSetNodes and error, if any. +// If the key doesn't exist, both the pointer and the error will be nil. +// In case of deserialization errors, returns nil and the error. +func (zs *ZSetStructure) getZSetFromDB(key []byte) (*ZSetNodes, error) { + dbData, err := zs.db.Get(key) + + // If key is not found, return nil for both; otherwise return the error. + if err != nil { - // Since the key might not exist, we need to handle ErrKeyNotFound separately as it is a valid case - if err != nil && err != _const.ErrKeyNotFound { return nil, err } + + // Deserialize the data. var zSetValue ZSetNodes - // Deserialize the data into a list - err = encoding.DecodeMessagePack(dbData, zSetValue) - if err != nil { + if err := encoding.DecodeMessagePack(dbData, zSetValue); err != nil { return nil, err } + // return a pointer to the deserialized ZSetNodes, nil for the error return &zSetValue, nil } -func (l *ZSetStructure) setZSetToDB(key []byte, zSetValue ZSetNodes) error { - // Deserialize the data into a list + +// setZSetToDB writes a ZSetNodes object to the database. +// +// parameters: +// key: This is a byte slice that is used as a key in the database. +// zSetValue: This is a pointer to a ZSetNodes object that needs to be stored in the database. +// +// The function serializes the ZSetNodes object into MessagePack format. If an error occurs +// either during serialization or when writing to the database, that specific error is returned. +// If the process is successful, it returns nil. +func (zs *ZSetStructure) setZSetToDB(key []byte, zSetValue *ZSetNodes) error { val, err := encoding.EncodeMessagePack(zSetValue) if err != nil { return err } - err = l.db.Put(key, val) - if err != nil { - return err + return zs.db.Put(key, val) +} + +// UnmarshalBinary de-serializes the given byte slice into ZSetNodes instance +// it uses MessagePack format for de-serialization +// Returns an error if the decoding of size or insertion of node fails. +// +// Parameters: +// data : a slice of bytes to be decoded +// +// Returns: +// An error that will be nil if the function succeeds. +func (p *ZSetNodes) UnmarshalBinary(data []byte) (err error) { + // NewMessagePackDecoder creates a new MessagePack decoder with the provided data + dec := encoding.NewMessagePackDecoder(data) + + var size int + // Decode the size of the data structure + if err = dec.Decode(&size); err != nil { + return err // error handling if something goes wrong with decoding } - return nil + + // Iterate through each node in the data structure + for i := 0; i < size; i++ { + // Create an empty instance of SkipListNodeValue for each node + slValue := SkipListNodeValue{} + + // Decode each node onto the empty SkipListNodeValue instance + if err = dec.Decode(&slValue); err != nil { + return err // error handling if something goes wrong with decoding + } + + // Insert the decoded node into the ZSetNodes instance + if err = p.InsertNode(slValue.score, slValue.member, slValue.value); err != nil { + return err + } + } + return // if all nodes are correctly decoded and inserted, return with nil error } -func zSetNodesDecoder(value reflect.Value, i []byte) error { - bs := ZSetNode{} - var bytesRead int - num, s, err := encoding.DecodeString(i) +// MarshalBinary serializes the ZSetNodes instance into a byte slice. +// It uses MessagePack format for serialization +// Returns the serialized byte slice and an error if the encoding fails. +func (d *ZSetNodes) MarshalBinary() (_ []byte, err error) { + + // Initializing the MessagePackEncoder + enc := encoding.NewMessagePackEncoder() + + // Encoding the size attribute of d (i.e., d.size). The operation could fail, thus we check for an error. + // An error, if occurred, will be returned immediately, hence the flow of execution stops here. + err = enc.Encode(d.size) if err != nil { + return nil, err + } + + // This is the start of a loop going over all the nodes in d's skip list from the tail of the + // list to the head. + // The tail and head pointers refer to the last and first element of the list, respectively, + // and are maintained for efficient traversing of the list. + // we do that to get the elements in reverse order from biggest to the smallest for the best + // insertion efficiency as it makes the insertion O(1), because each new element to be inserted is + // the smallest yet. + x := d.skipList.tail + // as long as there are elements in the SkipList continue + for x != nil { + // Encoding the value of the current node in the skip list + // Again, if an error occurs it gets immediately returned, thus breaking the loop. + err = enc.Encode(x.value) + if err != nil { + return nil, err + } + + // Move to the previous node in the skip list. + x = x.prev + } + + // After the traversal of the skip list, the encoder should now hold the serialized representation of the + // ZSetNodes. Now, we return the bytes from the encoder along with any error that might have occurred + // during the encoding (should be nil if everything went fine). + return enc.Bytes(), err +} + +// UnmarshalBinary de-serializes the given byte slice into SkipListNodeValue instance +// It uses the MessagePack format for de-serialization +// Returns an error if the decoding of Key, Score, or Value fails. +func (p *SkipListNodeValue) UnmarshalBinary(data []byte) (err error) { + dec := encoding.NewMessagePackDecoder(data) + if err = dec.Decode(&p.member); err != nil { + return + } + if err = dec.Decode(&p.score); err != nil { return err } - bytesRead += num - bs.Value = s - val, num := binary.Varint(i[bytesRead:]) - bytesRead += num - bs.Index = int(val) - val, num = binary.Varint(i[bytesRead:]) - bytesRead += num - bs.Priority = int(val) - value.Set(reflect.ValueOf(bs)) - return nil + if err = dec.Decode(&p.value); err != nil { + return + } + return } -func zSetNodesEncoder(value reflect.Value) ([]byte, error) { - zsn := value.Interface().(ZSetNode) - if zsn.Value == "" { - return nil, errors.New("empty zset") + +// MarshalBinary uses MessagePack as the encoding format to serialize +// the SkipListNodeValue object into a byte array. +func (d *SkipListNodeValue) MarshalBinary() (_ []byte, err error) { + + // The NewMessagePackEncoder function is called to create a new + // MessagePack encoder. + enc := encoding.NewMessagePackEncoder() + + // Then, we try to encode the 'key' field of the SkipListNodeValue + // If an error occurs, it is returned immediately along with the + // currently encoded byte slice. + if err = enc.Encode(d.member); err != nil { + return enc.Bytes(), err } - buf := bytes.NewBuffer(nil) - es, err := encoding.EncodeString(zsn.Value) - if err != nil { - return nil, err + + // We do the same for the 'score' field. + if err = enc.Encode(d.score); err != nil { + return enc.Bytes(), err } - _, err = buf.Write(es) - if err != nil { - return nil, err + + // Lastly, the 'value' field is encoded in the same way. + if err = enc.Encode(d.value); err != nil { + return enc.Bytes(), err } - b := make([]byte, binary.MaxVarintLen64) - written := 0 - written += binary.PutVarint(b[:], int64(zsn.Index)) - written += binary.PutVarint(b[written:], int64(zsn.Priority)) - buf.Write(b[:written]) - return buf.Bytes(), nil + + // If everything goes well and we're done encoding, we return the + // final byte slice which represents the encoded SkipListNodeValue + // and a nil error. + return enc.Bytes(), err } diff --git a/structure/zset_test.go b/structure/zset_test.go index e8f0f025..264a4cd3 100644 --- a/structure/zset_test.go +++ b/structure/zset_test.go @@ -1,61 +1,200 @@ package structure import ( - "container/heap" + "github.com/ByteStorage/FlyDB/config" + _const "github.com/ByteStorage/FlyDB/lib/const" "github.com/stretchr/testify/assert" - "reflect" + "os" "testing" ) +func initZSetDB() (*ZSetStructure, *config.Options) { + opts := config.DefaultOptions + dir, _ := os.MkdirTemp("", "TestZSetStructure") + opts.DirPath = dir + hash, _ := NewZSetStructure(opts) + return hash, &opts +} + func TestSortedSet(t *testing.T) { - items := map[string]int{ - "banana": 3, "apple": 2, "pear": 4, - } - // Create a priority queue, put the items in it, and - // establish the priority queue (heap) invariants. - pq := ZSetNodes{} - pq = make([]*ZSetNode, len(items)) - i := 0 - for value, priority := range items { - pq[i] = &ZSetNode{ - Value: value, - Priority: priority, - Index: i, - } - i++ + type test struct { + name string + input map[string]int + want *ZSetNodes + expectError bool + } + + zs := newZSetNodes() + err := zs.InsertNode(3, "banana", "hello") + err = zs.InsertNode(1, "apple", "hello") + err = zs.InsertNode(2, "pear", "hello") + err = zs.InsertNode(44, "orange", "hello") + err = zs.InsertNode(9, "strawberry", "delish") + err = zs.InsertNode(15, "dragon-fruit", "nonDelish") + t.Log(zs.skipList.getRank(9, "strawberry")) + t.Log(zs.skipList.getNodeByRank(1)) + t.Log(zs.skipList.getNodeByRank(2)) + t.Log(zs.skipList.getNodeByRank(3)) + t.Log(zs.skipList.getNodeByRank(5)) + //var bufEnc bytes.Buffer + //enc := gob.NewEncoder(&bufEnc) + //err = enc.Encode(zs) + //assert.NoError(t, err) + b, err := zs.Bytes() + t.Log(b) + + fromBytes := newZSetNodes() + //buf := bytes.NewBuffer(bufEnc.Bytes()) + //gd := gob.NewDecoder(buf) + //err = gd.Decode(fromBytes.FromBytes(b)) + //assert.NoError(t, err) + + t.Log(fromBytes.FromBytes(b)) + //t.Log(fromBytes) + assert.NoError(t, err) + + tests := []test{ + { + name: "empty", + input: map[string]int{}, + want: &ZSetNodes{}, + expectError: false, + }, + { + name: "three fruits", + input: map[string]int{"banana": 3, "apple": 2, "pear": 4, "peach": 40}, + want: nil, + }, } - heap.Init(&pq) - pq.Push(&ZSetNode{"Pineapple", 50, 0}) - //heap.Fix(&pq, len(pq)-1) - //pq.update(pq[0], pq[0].value, 0) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.ElementsMatch(t, tt.want, nil) + }) + } - t.Log(pq) } func TestSortedSet_Bytes(t *testing.T) { - items := map[string]int{ - "banana": 3, "apple": 2, "pear": 4, - } - // Create a priority queue, put the items in it, and - // establish the priority queue (heap) invariants. - pq := ZSetNodes{} - pq = make([]*ZSetNode, len(items)) - i := 0 - for value, priority := range items { - pq[i] = &ZSetNode{ - Value: value, - Priority: priority, - Index: i, + +} + +func TestNewSkipList(t *testing.T) { + s := newSkipList() + + assert := assert.New(t) + assert.Equal(1, s.level) + assert.Nil(s.head.prev) + assert.Equal(0, s.head.value.score) + assert.Equal("", s.head.value.member) +} + +func TestNewSkipListNode(t *testing.T) { + score := 10 + key := "test_key" + value := "test_value" + level := 5 + + node := newSkipListNode(level, score, key, value) + + // Validate node's value + if node.value.score != score || node.value.member != key || node.value.value != value { + t.Errorf("Unexpected value in node, got: %v, want: {score: %d, key: %s, val: %s}.\n", node.value, score, key, value) + } + + // Validate node's level slice length + if len(node.level) != level { + t.Errorf("Unexpected length of node's level slice, got: %d, want: %d.\n", len(node.level), level) + } + + // Validate each SkipListLevel in the level slice + for _, l := range node.level { + if l.next != nil || l.span != 0 { + t.Errorf("Unexpected SkipListLevel, got: %v, want: {forward: nil, span: 0}.\n", l) } - i++ + } +} +func TestZAdd(t *testing.T) { + zs, _ := initZSetDB() + type testCase struct { + key string + score int + member string + value string + err error } - heap.Init(&pq) - b, err := pq.Bytes() - assert.NoError(t, err) - rb := ZSetNodes{} - err = rb.FromBytes(b) - assert.NoError(t, err) - assert.True(t, reflect.DeepEqual(rb, pq)) + testCases := []testCase{ + {"key", 10, "member", "value", nil}, + {"", 10, "member", "value", _const.ErrKeyIsEmpty}, + } + + for _, tc := range testCases { + err := zs.ZAdd(tc.key, tc.score, tc.member, tc.value) + // Adjust according to your error handling + if err != tc.err { + t.Errorf("Expected error to be %v, but got %v", tc.err, err) + } + } +} + +func TestSkipList_delete(t *testing.T) { + type deleteTest struct { + name string + score int + member string + targetList []testZSetNodeValue + inputList []testZSetNodeValue + } + + vals := []testZSetNodeValue{ + {score: 1, member: "mem1", value: nil}, + {score: 2, member: "mem2", value: nil}, + {score: 3, member: "mem3", value: nil}, + {score: 4, member: "mem4", value: nil}, + {score: 5, member: "mem5", value: nil}, + } + + // Omitted: Add some nodes into sl... + + tests := []deleteTest{ + { + name: "Delete Test 1", + score: 15, + member: "member1", + targetList: []testZSetNodeValue{{score: 3, member: "mem3"}}, // result of adding nodes into sl + inputList: vals, + }, + // Add more test cases here... + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + head := newZSetNodes() + populateSkipListFromSlice(head, test.inputList) + + for _, value := range test.targetList { + // check if the insertion has been performed + assert.True(t, head.exists(value.score, value.member)) + // delete the target members + assert.NoError(t, head.RemoveNode(value.member)) + // check to see if the deletion has been correctly performed + assert.False(t, head.exists(value.score, value.member)) + + } + }) + } +} + +type testZSetNodeValue struct { + score int + member string + value interface{} +} + +func populateSkipListFromSlice(nodes *ZSetNodes, zSetNodeValues []testZSetNodeValue) { + // Iterate over the zsetNodes array + for _, zSetNode := range zSetNodeValues { + _ = nodes.InsertNode(zSetNode.score, zSetNode.member, zSetNode.value) + } } From 1e758db1b642967fcb68d6a08c5c10149a17abd8 Mon Sep 17 00:00:00 2001 From: Saeid Aghapour Date: Tue, 18 Jul 2023 01:41:40 +0330 Subject: [PATCH 3/5] add test units and minor changes(#179) --- lib/encoding/messagepack.go | 48 +-- lib/encoding/messagepack_test.go | 136 ++++++++- structure/zset.go | 152 ++++++++-- structure/zset_test.go | 490 ++++++++++++++++++++++++++++--- 4 files changed, 727 insertions(+), 99 deletions(-) diff --git a/lib/encoding/messagepack.go b/lib/encoding/messagepack.go index 075ba371..46771491 100644 --- a/lib/encoding/messagepack.go +++ b/lib/encoding/messagepack.go @@ -20,33 +20,32 @@ type MessagePackCodec struct { // it manages IDs and counts of the encoded objects. type MessagePackCodecEncoder struct { MessagePackCodec // Embedded MessagePackCodec - - // nextId is used probably for tracking ID of the next object to encode. - nextId uint - - // objects represents the count of objects that have been encoded. - objects int } // MessagePackCodecDecoder struct, holds a reference to a MessagePackCodec instance. type MessagePackCodecDecoder struct { - MessagePackCodec + *MessagePackCodec +} + +// NewMsgPackHandle is a helper function to create a new instance of MsgpackHandle +func NewMsgPackHandle() *codec.MsgpackHandle { + return &codec.MsgpackHandle{} } // InitMessagePack function initializes MessagePackCodec struct and returns it. func InitMessagePack() MessagePackCodec { return MessagePackCodec{ - MsgPack: &codec.MsgpackHandle{}, + MsgPack: NewMsgPackHandle(), } } // NewMessagePackEncoder function creates new MessagePackCodecEncoder and initializes it. func NewMessagePackEncoder() *MessagePackCodecEncoder { - msgPack := &codec.MsgpackHandle{} + msgPack := NewMsgPackHandle() b := make([]byte, 0) return &MessagePackCodecEncoder{ MessagePackCodec: MessagePackCodec{ - MsgPack: &codec.MsgpackHandle{}, + MsgPack: msgPack, b: &b, enc: codec.NewEncoderBytes(&b, msgPack), }, @@ -56,10 +55,10 @@ func NewMessagePackEncoder() *MessagePackCodecEncoder { // NewMessagePackDecoder function takes in a byte slice, and returns a pointer to newly created // and initialized MessagePackCodecDecoder func NewMessagePackDecoder(b []byte) *MessagePackCodecDecoder { - msgPack := &codec.MsgpackHandle{} + msgPack := NewMsgPackHandle() return &MessagePackCodecDecoder{ - MessagePackCodec: MessagePackCodec{ - MsgPack: &codec.MsgpackHandle{}, + MessagePackCodec: &MessagePackCodec{ + MsgPack: msgPack, b: &b, dec: codec.NewDecoderBytes(b, msgPack), }, @@ -122,7 +121,7 @@ func (m *MessagePackCodec) AddExtension( func EncodeMessagePack(msg interface{}) ([]byte, error) { // Directly initialize the byte slice and encoder. b := make([]byte, 0) - enc := codec.NewEncoderBytes(&b, &codec.MsgpackHandle{}) + enc := codec.NewEncoderBytes(&b, NewMsgPackHandle()) // Attempt to encode the message. if err := enc.Encode(msg); err != nil { @@ -135,7 +134,7 @@ func EncodeMessagePack(msg interface{}) ([]byte, error) { // DecodeMessagePack function decodes a byte slice of MessagePack data into a given object. func DecodeMessagePack(in []byte, out interface{}) error { - dec := codec.NewDecoder(bytes.NewBuffer(in), &codec.MsgpackHandle{}) + dec := codec.NewDecoder(bytes.NewBuffer(in), NewMsgPackHandle()) return dec.Decode(out) } @@ -157,12 +156,19 @@ func EncodeString(s string) ([]byte, error) { return b, nil } -// DecodeString is a function that takes an input byte slice and attempts to decode it to obtain a string. -// Return parameters are an integer, a string and an error. Integer denotes the length of the byte slice -// representation of the string including length-field. The second return parameter is the decoded string. -// DecodeString raises an error if the length of byte slice is less than the expected string length plus -// one (considering the string length field) or if the provided byte slice is empty. -// If successful, returns length of byte representation of string, the decoded string and a nil error. +/* +DecodeString converts an input byte slice into a string. +Arguments: + + b: Input byte slice to be decoded. + +Returns: +- Integer: Length of byte representation of the string. +- String: Decoded string. +- Error: 'invalid length' if slice is empty or insufficient length, else nil. + +The function reads the first byte as string length (l), creates a slice of length l and returns the formed string. +*/ func DecodeString(b []byte) (int, string, error) { // Check that byte slice is not empty. if len(b) == 0 { diff --git a/lib/encoding/messagepack_test.go b/lib/encoding/messagepack_test.go index 1fb39808..0f987fbf 100644 --- a/lib/encoding/messagepack_test.go +++ b/lib/encoding/messagepack_test.go @@ -5,6 +5,7 @@ import ( "github.com/hashicorp/go-msgpack/codec" "github.com/hashicorp/raft" "github.com/stretchr/testify/assert" + "reflect" "testing" ) @@ -66,8 +67,8 @@ func TestMsgPackDecoder_Decode_ErrDecoderNotInitialized(t *testing.T) { msgPack := InitMessagePack() encoded, _ := msgPack.Encode("example message") decoder := &MessagePackCodecDecoder{ - MessagePackCodec: MessagePackCodec{ - MsgPack: &codec.MsgpackHandle{}, + MessagePackCodec: &MessagePackCodec{ + MsgPack: NewMsgPackHandle(), b: &encoded, }, } @@ -79,15 +80,6 @@ func TestMsgPackDecoder_Decode_ErrDecoderNotInitialized(t *testing.T) { assert.Equal(t, "decoder not initialized", err.Error()) } -func TestEncodeString(t *testing.T) { - s := "example string" - - encoded, err := EncodeString(s) - - assert.NotNil(t, encoded) - assert.Nil(t, err) -} - func TestEncodeString_ErrStringLength(t *testing.T) { s := bytes.Repeat([]byte("a"), 0x80) // 128-byte long string @@ -97,3 +89,125 @@ func TestEncodeString_ErrStringLength(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, "invalid string length", err.Error()) } + +func TestMessagePackCodec_Encode(t *testing.T) { + type TestStruct struct { + Field1 string + Field2 int + } + assert := assert.New(t) + + codec := &MessagePackCodec{ + MsgPack: NewMsgPackHandle(), + } + + t.Run("successful encoding", func(t *testing.T) { + testStruct := &TestStruct{ + Field1: "Test", + Field2: 1, + } + outStruct := &TestStruct{} + + b, err := codec.Encode(testStruct) + assert.Nil(err, "Error should be nil") + err = codec.Decode(b, outStruct) + assert.NoError(err) + assert.EqualValues(testStruct, outStruct) + }) + +} +func TestEncodeString(t *testing.T) { + tests := []struct { + input string + expected []byte + hasErr bool + }{ + {"hello", []byte{0x05, 'h', 'e', 'l', 'l', 'o'}, false}, + {"world", []byte{0x05, 'w', 'o', 'r', 'l', 'd'}, false}, + {string(make([]byte, 0x80)), nil, true}, + } + + for _, tt := range tests { + result, err := EncodeString(tt.input) + if (err != nil) != tt.hasErr { + t.Errorf("EncodeString(%q) error = %v, wantErr %v", tt.input, err, tt.hasErr) + continue + } + if !tt.hasErr && string(result) != string(tt.expected) { + t.Errorf("EncodeString(%q) = %q, want %q", tt.input, result, tt.expected) + } + } +} + +func TestDecodeString(t *testing.T) { + tests := []struct { + input []byte + expectedL int + expected string + hasErr bool + }{ + {[]byte{0x05, 'h', 'e', 'l', 'l', 'o'}, 6, "hello", false}, + {[]byte{0x05, 'w', 'o', 'r', 'l', 'd'}, 6, "world", false}, + {[]byte{0x05, 'w', 'o', 'r'}, 0, "", true}, + {[]byte{}, 0, "", true}, + } + + for _, tt := range tests { + length, result, err := DecodeString(tt.input) + if (err != nil) != tt.hasErr { + t.Errorf("DecodeString(%v) error = %v, wantErr %v", tt.input, err, tt.hasErr) + continue + } + if !tt.hasErr && (result != tt.expected || length != tt.expectedL) { + t.Errorf("DecodeString(%v) = %v,%q, want %v,%q", tt.input, length, result, tt.expectedL, tt.expected) + } + } +} +func TestMessagePackCodecEncoder_Encode(t *testing.T) { + var mh codec.MsgpackHandle + mh.MapType = reflect.TypeOf(map[int]int{}) + + encoder := NewMessagePackEncoder() + + err := encoder.Encode(map[int]int{1: 2}) + assert.NoError(t, err) + + err = encoder.Encode(map[int]int{3: 4}) + assert.NoError(t, err) +} + +func TestAddExtension(t *testing.T) { + + type CustomType struct { + Name string + } + + // global extention info. + const extensionID byte = 1 + + encoder := func(rv reflect.Value) ([]byte, error) { + ct := rv.Interface().(CustomType) + return []byte(ct.Name), nil + } + decoder := func(rv reflect.Value, b []byte) error { + rv.Set(reflect.ValueOf(CustomType{Name: string(b)})) + return nil + } + + m := NewMessagePackEncoder() + err := m.AddExtension(reflect.TypeOf(CustomType{}), extensionID, encoder, decoder) + if err != nil { + t.Fatalf("Failed adding extension: %v", err) + } + data := CustomType{Name: "test"} + dataVerify := CustomType{} + err = m.enc.Encode(&data) + assert.NoError(t, err) + assert.NotNil(t, m.Bytes()) + d := NewMessagePackDecoder(m.Bytes()) + err = d.AddExtension(reflect.TypeOf(CustomType{}), extensionID, encoder, decoder) + assert.NoError(t, err) + err = d.Decode(&dataVerify) + assert.NoError(t, err) + assert.EqualValues(t, dataVerify, data) +} diff --git a/structure/zset.go b/structure/zset.go index 236eae33..1494e46b 100644 --- a/structure/zset.go +++ b/structure/zset.go @@ -333,20 +333,23 @@ func (sl *SkipList) getRange(start int, end int, reverse bool) (nv []SkipListNod if start > end { return } - if end < 0 { + if end <= 0 { return nil // todo unexpected behavior, we can set it to zero as well } + node := sl.head if reverse { - node = sl.getNodeByRank(end) - } else { - node = sl.getNodeByRank(start) - } - if reverse { - node = sl.getNodeByRank(end) + node = sl.tail + if start > 0 { + node = sl.getNodeByRank(sl.size - start) + } } else { - node = sl.getNodeByRank(start) + node = node.level[0].next + if start > 0 { + node = sl.getNodeByRank(start + 1) + } } + for i := start; i < end; i++ { if reverse { nv = append(nv, *node.value) @@ -503,6 +506,20 @@ func (zs *ZSetStructure) ZAdd(key string, score int, member string, value string return nil } +func (zs *ZSetStructure) exists(key string, score int, member string) bool { + if len(key) == 0 { + return false + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + + if err != nil { + return false + } + + return zSet.exists(score, member) +} /* ZRem is a method belonging to ZSetStructure that removes a member from a ZSet. @@ -545,7 +562,7 @@ func (zs *ZSetStructure) ZScore(key string, member string) (int, error) { zSet, err := zs.getZSetFromDB(keyBytes) if err != nil { - return 0, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + return 0, err } // if the member in the sorted set is found, return the score associated with it if v, ok := zSet.dict[member]; ok { @@ -601,6 +618,25 @@ func (zs *ZSetStructure) ZRank(key string, member string) (int, error) { // rank zero means no rank found return 0, _const.ErrKeyNotFound } + +// ZRevRank calculates the reverse rank of a member in a ZSet (Sorted Set) associated with a given key. +// ZSet exploits the Sorted Set data structure of Redis with O(log(N)) time complexity for Fetching the rank. +// +// Parameters: +// +// key: This is a string that serves as the key of a ZSet stored in the database. +// member: This is a string that represents a member of a ZSet whose rank needs to be obtained. +// +// Returns: +// +// int: The integer represents the reverse rank of the member in the ZSet. It returns 0 if the member is not found in the ZSet. +// On successful execution, it returns the difference of the ZSet size and the member's rank. +// error: The error which will be null if no errors occurred. If the key provided is empty, an ErrKeyIsEmpty error is returned. +// If there's a problem getting or creating the ZSet from the database, an error message is returned with the format +// "failed to get or create ZSet from DB with key '%v': %w", where '%v' is the key and '%w' shows the error detail. +// If the member is not found in the ZSet, it returns an ErrKeyNotFound error. +// +// Note: The reverse rank is calculated as 'size - rank', and the ranks start from 1. func (zs *ZSetStructure) ZRevRank(key string, member string) (int, error) { if len(key) == 0 { return 0, _const.ErrKeyIsEmpty @@ -613,12 +649,43 @@ func (zs *ZSetStructure) ZRevRank(key string, member string) (int, error) { } if v, ok := zSet.dict[member]; ok { rank := zSet.skipList.getRank(v.score, member) - return zSet.size - rank, nil + return (zSet.size) - rank + 1, nil } // rank zero means no rank found return 0, _const.ErrKeyNotFound } + +// ZRange retrieves a specific range of elements from a sorted set (ZSet) denoted by a specific key. +// It returns a slice of SkipListNodeValue containing the elements within the specified range (inclusive), and a nil error when successful. +// +// The order of the returned elements is based on their rank in the set, not their score. +// +// Parameters: +// +// key: A string identifier representing the ZSet. The key shouldn't be an empty string. +// start: A zero-based integer representing the first index of the range. +// end: A zero-based integer representing the last index of the range. +// +// Returns: +// +// []SkipListNodeValue: +// Slice of SkipListNodeValue containing elements within the specified range. +// error: +// An error if it occurs during execution, such as: +// 1. The provided key string is empty. +// 2. An error occurs while fetching the ZSet from the database, i.e., the ZSet represented by the given key doesn't exist. +// In the case of an error, an empty slice and the actual error encountered will be returned. +// +// Note: +// On successful execution, ZRange returns the elements starting from 'start' index up to 'end' index inclusive. +// If the set doesn't exist or an error occurs during execution, ZRange returns an empty slice and the error. +// +// Example: +// Assume we have ZSet with the following elements: ["element1", "element2", "element3", "element4"] +// ZRange("someKey", 0, 2) will return ["element1", "element2", "element3"] and nil error. +// +// This method is part of the ZSetStructure type. func (zs *ZSetStructure) ZRange(key string, start int, end int) ([]SkipListNodeValue, error) { if len(key) == 0 { return nil, _const.ErrKeyIsEmpty @@ -627,14 +694,25 @@ func (zs *ZSetStructure) ZRange(key string, start int, end int) ([]SkipListNodeV zSet, err := zs.getZSetFromDB(keyBytes) if err != nil { - return nil, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + return nil, err } r := zSet.skipList.getRange(start, end, false) // rank zero means no rank found return r, nil } -func (zs *ZSetStructure) ZRevRange(key string, start int, end int) ([]SkipListNodeValue, error) { + +// ZRevRange retrieves a range of elements from a sorted set (ZSet) in descending order. +// Inputs: +// - key: Name of the ZSet +// - startRank: Initial rank of the desired range +// - endRank: Final rank of the desired range +// +// Output: +// - An array of SkipListNodeValue, representing elements from the range [startRank, endRank] in descending order +// - Error if an issue occurs, such as when the key is empty or ZSet retrieval fails +// error +func (zs *ZSetStructure) ZRevRange(key string, startRank int, endRank int) ([]SkipListNodeValue, error) { if len(key) == 0 { return nil, _const.ErrKeyIsEmpty } @@ -642,13 +720,16 @@ func (zs *ZSetStructure) ZRevRange(key string, start int, end int) ([]SkipListNo zSet, err := zs.getZSetFromDB(keyBytes) if err != nil { - return nil, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + return nil, err } - r := zSet.skipList.getRange(start, end, true) + r := zSet.skipList.getRange(startRank, endRank, true) // rank zero means no rank found return r, nil } + +// The ZCard function returns the size of the dictionary of the sorted set stored at key in the database. +// It takes a string key as an argument. func (zs *ZSetStructure) ZCard(key string) (int, error) { if len(key) == 0 { return 0, _const.ErrKeyIsEmpty @@ -657,23 +738,47 @@ func (zs *ZSetStructure) ZCard(key string) (int, error) { zSet, err := zs.getZSetFromDB(keyBytes) if err != nil { - return 0, fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + return 0, err } // get the size of the dictionary return zSet.size, nil } + +// ZIncrBy increases the score of an existing member in a sorted set stored at specified key by +// the increment `incBy` provided. If member does not exist, ErrKeyNotFound error is returned. +// If the key does not exist, it treats it as an empty sorted set and returns an error. +// +// The method accepts three parameters: +// `key`: a string type parameter that identifies the sorted set +// `member`: a string type parameter representing member in the sorted set +// `incBy`: an int type parameter provides the increment value for a member score +// +// The method throws error under following circumstances - +// if provided key is empty (ErrKeyIsEmpty error), +// if provided key or member is not present in the database (ErrKeyNotFound error), +// if it's unable to fetch or create ZSet from DB, +// if there's an issue with node insertion, +// if unable to set ZSet to DB post increment operation func (zs *ZSetStructure) ZIncrBy(key string, member string, incBy int) error { if len(key) == 0 { return _const.ErrKeyIsEmpty } + keyBytes := stringToBytesWithKey(key) zSet, err := zs.getZSetFromDB(keyBytes) if err != nil { return fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) } + if v, ok := zSet.dict[member]; ok { - return zSet.InsertNode(v.score+incBy, member, v.value) + if err = zSet.InsertNode(v.score+incBy, member, v.value); err != nil { + return err + } + if err = zs.setZSetToDB(keyBytes, zSet); err != nil { + return err + } + return nil } return _const.ErrKeyNotFound @@ -728,10 +833,13 @@ func (pq *ZSetNodes) InsertNode(score int, member string, value interface{}) err if pq.dict == nil { pq.dict = make(map[string]*SkipListNodeValue) } + if pq.skipList == nil { + pq.skipList = newSkipList() + } // Check if key exists in dictionary if v, ok := pq.dict[member]; ok { - if v.score == score { + if v.score != score { // Update value and score as the score remains the same pq.skipList.delete(score, member) pq.dict[member] = pq.skipList.insert(score, member, value) @@ -820,12 +928,13 @@ func (zs *ZSetStructure) getZSetFromDB(key []byte) (*ZSetNodes, error) { return nil, err } - + dec := encoding.NewMessagePackDecoder(dbData) // Deserialize the data. var zSetValue ZSetNodes - if err := encoding.DecodeMessagePack(dbData, zSetValue); err != nil { + if err = dec.Decode(&zSetValue); err != nil { return nil, err } + // return a pointer to the deserialized ZSetNodes, nil for the error return &zSetValue, nil } @@ -840,11 +949,12 @@ func (zs *ZSetStructure) getZSetFromDB(key []byte) (*ZSetNodes, error) { // either during serialization or when writing to the database, that specific error is returned. // If the process is successful, it returns nil. func (zs *ZSetStructure) setZSetToDB(key []byte, zSetValue *ZSetNodes) error { - val, err := encoding.EncodeMessagePack(zSetValue) + val := encoding.NewMessagePackEncoder() + err := val.Encode(zSetValue) if err != nil { return err } - return zs.db.Put(key, val) + return zs.db.Put(key, val.Bytes()) } // UnmarshalBinary de-serializes the given byte slice into ZSetNodes instance diff --git a/structure/zset_test.go b/structure/zset_test.go index 264a4cd3..87d0d8ae 100644 --- a/structure/zset_test.go +++ b/structure/zset_test.go @@ -4,7 +4,10 @@ import ( "github.com/ByteStorage/FlyDB/config" _const "github.com/ByteStorage/FlyDB/lib/const" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "math/rand" "os" + "reflect" "testing" ) @@ -31,28 +34,13 @@ func TestSortedSet(t *testing.T) { err = zs.InsertNode(44, "orange", "hello") err = zs.InsertNode(9, "strawberry", "delish") err = zs.InsertNode(15, "dragon-fruit", "nonDelish") - t.Log(zs.skipList.getRank(9, "strawberry")) - t.Log(zs.skipList.getNodeByRank(1)) - t.Log(zs.skipList.getNodeByRank(2)) - t.Log(zs.skipList.getNodeByRank(3)) - t.Log(zs.skipList.getNodeByRank(5)) - //var bufEnc bytes.Buffer - //enc := gob.NewEncoder(&bufEnc) - //err = enc.Encode(zs) - //assert.NoError(t, err) b, err := zs.Bytes() - t.Log(b) fromBytes := newZSetNodes() - //buf := bytes.NewBuffer(bufEnc.Bytes()) - //gd := gob.NewDecoder(buf) - //err = gd.Decode(fromBytes.FromBytes(b)) - //assert.NoError(t, err) - - t.Log(fromBytes.FromBytes(b)) - //t.Log(fromBytes) + err = fromBytes.FromBytes(b) assert.NoError(t, err) - + assert.NotNil(t, fromBytes.skipList) + assert.Equal(t, fromBytes.size, zs.size) tests := []test{ { name: "empty", @@ -79,14 +67,404 @@ func TestSortedSet_Bytes(t *testing.T) { } +func TestZRem(t *testing.T) { + mockZSetStructure, _ := initZSetDB() + + // 1. Test for Key is Empty + err := mockZSetStructure.ZRem("", "member") + require.Error(t, err) + require.Equal(t, _const.ErrKeyIsEmpty, err) + type testCase struct { + key string + score int + member string + value string + err error + } + + testCases := []testCase{ + {"key", 10, "member", "value", nil}, + {"", 10, "member", "value", _const.ErrKeyIsEmpty}, + } + + for _, tc := range testCases { + t.Run(tc.key, func(t *testing.T) { + err := mockZSetStructure.ZAdd(tc.key, tc.score, tc.member, tc.value) + // check to see if element added + assert.Equal(t, tc.err, err) + if tc.err == nil { + // check if member added + assert.True(t, mockZSetStructure.exists(tc.key, tc.score, tc.member)) + } + + }) + } +} +func TestZAdd(t *testing.T) { + zs, _ := initZSetDB() + type testCase struct { + key string + score int + member string + value string + want SkipListNodeValue + err error + } + + testCases := []testCase{ + { + "key", + 10, + "member", + "value", + SkipListNodeValue{member: "member"}, + nil, + }, + { + "", + 10, + "member", + "value", + SkipListNodeValue{member: ""}, + _const.ErrKeyIsEmpty, + }, + } + + for _, tc := range testCases { + t.Run(tc.key, func(t *testing.T) { + err := zs.ZAdd(tc.key, tc.score, tc.member, tc.value) + assert.Equal(t, tc.err, err) + if tc.err == nil { + // check if member added + assert.True(t, zs.exists(tc.key, tc.score, tc.member)) + err = zs.ZRem(tc.key, tc.member) + assert.NoError(t, err) + // should be removed successfully + assert.False(t, zs.exists(tc.key, tc.score, tc.member)) + } + // Adjust according to your error handling + + }) + } +} +func TestZIncrBy(t *testing.T) { + zs, _ := initZSetDB() + err := zs.ZIncrBy("", "non-existingMember", 5) + if err == nil { + t.Error("Expected error for empty key not returned") + } + + err = zs.ZIncrBy("key", "non-existingMember", 5) + if !assert.ErrorIs(t, err, _const.ErrKeyNotFound) { + t.Error("Expected ErrKeyNotFound for non-existing member not returned") + } + err = zs.ZAdd("key", 1, "existingMember", "") + assert.NoError(t, err) + err = zs.ZIncrBy("key", "existingMember", 5) + assert.NoError(t, err) + + Zset, err := zs.getZSetFromDB(stringToBytesWithKey("key")) + assert.Equal(t, 6, Zset.dict["existingMember"].score) +} +func TestZRank(t *testing.T) { + zs, _ := initZSetDB() + + // Assume that ZAdd adds a member to a set and assigns the member a score. + // Here the score does not matter + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + // Test when member is present in the set + rank, err := zs.ZRank("myKey", "member1") + assert.NoError(t, err) // no error should occur + assert.Equal(t, 1, rank) // as we inserted 'member1' first, its rank should be 1 + + // Test when member is not present in the set + rank, err = zs.ZRank("myKey", "unavailableMember") + assert.Error(t, err) // an error should occur + assert.Equal(t, 0, rank) // as 'unavailableMember' is not part of set, rank should be 0 + + // Test with an empty key + rank, err = zs.ZRank("", "member") + assert.Error(t, err) // an error should occur + assert.Equal(t, 0, rank) // rank should be 0 for invalid key} + + // Test member2 which should be 2nd + rank, err = zs.ZRank("myKey", "member2") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 2, rank) // rank should be 2 for key `member2` + + // Test member3 which should be 3rd + rank, err = zs.ZRank("myKey", "member3") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 3, rank) + + // remove member2 and test `member3` which should become 2 + err = zs.ZRem("myKey", "member2") + assert.NoError(t, err) // there should be no errors + rank, err = zs.ZRank("myKey", "member3") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 2, rank) // now `member3` should become 2nd +} +func TestZRevRank(t *testing.T) { + zs, _ := initZSetDB() + + // Assume that ZAdd adds a member to a set and assigns the member a score. + // Here the score does not matter + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + // Test when member is present in the set + rank, err := zs.ZRevRank("myKey", "member3") + assert.NoError(t, err) // no error should occur + assert.Equal(t, 2, rank) // as we inserted 'member1' first, its rank should be 1 + + // Test when member is not present in the set + rank, err = zs.ZRevRank("myKey", "unavailableMember") + assert.Error(t, err) // an error should occur + assert.Equal(t, 0, rank) // as 'unavailableMember' is not part of set, rank should be 0 + + // Test with an empty key + rank, err = zs.ZRevRank("", "member") + assert.Error(t, err) // an error should occur + assert.Equal(t, 0, rank) // rank should be 0 for invalid key} + + // Test member2 which should be 2nd + rank, err = zs.ZRevRank("myKey", "member1") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 4, rank) // rank should be 2 for key `member2` + + // Test member3 which should be 3rd + rank, err = zs.ZRevRank("myKey", "member4") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 1, rank) + + // remove member2 and test `member3` which should become 2 + err = zs.ZRem("myKey", "member2") + assert.NoError(t, err) // there should be no errors + rank, err = zs.ZRevRank("myKey", "member3") + assert.NoError(t, err) // there should be no errors + assert.Equal(t, 2, rank) // now `member3` should become 2nd +} +func TestZRevRange(t *testing.T) { + zs, _ := initZSetDB() + + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 5, "member5", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 6, "member6", "") + assert.NoError(t, err) + + var n []uint8 + tests := []struct { + key string + start int + end int + want []SkipListNodeValue + wantErr error + }{ + {"myKey", 0, 3, []SkipListNodeValue{ + {6, "member6", n}, + {5, "member5", n}, + {4, "member4", n}, + }, nil}, + {"", 0, 2, nil, _const.ErrKeyIsEmpty}, + {"fail", 0, 2, nil, _const.ErrKeyNotFound}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + got, err := zs.ZRevRange(tt.key, tt.start, tt.end) + + if !reflect.DeepEqual(err, tt.wantErr) { + t.Errorf("ZRange() error = %v, wantErr %v", err, tt.wantErr) + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ZRange() = %v, want %v", got, tt.want) + } + }) + } +} +func TestZRange(t *testing.T) { + zs, _ := initZSetDB() + + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 5, "member5", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 6, "member6", "") + assert.NoError(t, err) + var n []uint8 + tests := []struct { + key string + start int + end int + want []SkipListNodeValue + wantErr error + }{ + {"myKey", 0, 3, []SkipListNodeValue{ + {1, "member1", n}, + {2, "member2", n}, + {3, "member3", n}, + }, nil}, + {"", 0, 2, nil, _const.ErrKeyIsEmpty}, + {"fail", 0, 2, nil, _const.ErrKeyNotFound}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + got, err := zs.ZRange(tt.key, tt.start, tt.end) + + if !reflect.DeepEqual(err, tt.wantErr) { + t.Errorf("ZRange() error = %v, wantErr %v", err, tt.wantErr) + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ZRange() = %v, want %v", got, tt.want) + } + }) + } +} +func TestZCard(t *testing.T) { + zs, _ := initZSetDB() + + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 5, "member5", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 6, "member6", "") + assert.NoError(t, err) + tests := []struct { + name string + key string + want int + wantErr error + }{ + {"Empty Key", "", 0, _const.ErrKeyIsEmpty}, + {"Non-Existent Key", "nonExist", 0, _const.ErrKeyNotFound}, + {"Existing Key", "myKey", 6, nil}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + got, err := zs.ZCard(tt.key) + + if tt.want != got { + t.Fatalf("expected %d, got %d", tt.want, got) + } + + if tt.wantErr != nil { + if err == nil || tt.wantErr.Error() != err.Error() { + t.Fatalf("expected error '%v', got '%v'", tt.wantErr, err) + } + + } else if err != nil { + t.Fatalf("expected no error, got error '%v'", err) + } + }) + } +} +func TestZScore(t *testing.T) { + zs, _ := initZSetDB() + + err := zs.ZAdd("myKey", 1, "member1", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 2, "member2", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 3, "member3", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 4, "member4", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 5, "member5", "") + assert.NoError(t, err) + + err = zs.ZAdd("myKey", 6, "member6", "") + assert.NoError(t, err) + tests := []struct { + expectError error + expectedScore int + key string + member string + }{ + {_const.ErrKeyIsEmpty, 0, "", "member1"}, + {_const.ErrKeyNotFound, 0, "key1", "foo"}, + {nil, 1, "myKey", "member1"}, + {nil, 2, "myKey", "member2"}, + } + + for _, test := range tests { + t.Run(test.key, func(t *testing.T) { + score, err := zs.ZScore(test.key, test.member) + assert.Equal(t, test.expectError, err) + assert.Equal(t, test.expectedScore, score) + }) + } +} func TestNewSkipList(t *testing.T) { s := newSkipList() - assert := assert.New(t) - assert.Equal(1, s.level) - assert.Nil(s.head.prev) - assert.Equal(0, s.head.value.score) - assert.Equal("", s.head.value.member) + assertions := assert.New(t) + assertions.Equal(1, s.level) + assertions.Nil(s.head.prev) + assertions.Equal(0, s.head.value.score) + assertions.Equal("", s.head.value.member) } func TestNewSkipListNode(t *testing.T) { @@ -114,29 +492,6 @@ func TestNewSkipListNode(t *testing.T) { } } } -func TestZAdd(t *testing.T) { - zs, _ := initZSetDB() - type testCase struct { - key string - score int - member string - value string - err error - } - - testCases := []testCase{ - {"key", 10, "member", "value", nil}, - {"", 10, "member", "value", _const.ErrKeyIsEmpty}, - } - - for _, tc := range testCases { - err := zs.ZAdd(tc.key, tc.score, tc.member, tc.value) - // Adjust according to your error handling - if err != tc.err { - t.Errorf("Expected error to be %v, but got %v", tc.err, err) - } - } -} func TestSkipList_delete(t *testing.T) { type deleteTest struct { @@ -198,3 +553,46 @@ func populateSkipListFromSlice(nodes *ZSetNodes, zSetNodeValues []testZSetNodeVa _ = nodes.InsertNode(zSetNode.score, zSetNode.member, zSetNode.value) } } +func TestRandomLevel(t *testing.T) { + rand.Seed(1) + + for i := 0; i < 1000; i++ { + level := randomLevel() + if level < 1 || level > SKIPLIST_MAX_LEVEL { + t.Errorf("Generated level out of range: %v", level) + } + } +} +func TestZSetNodes_InsertNode(t *testing.T) { + pq := &ZSetNodes{} + + // Case 1: Insert new node + err := pq.InsertNode(1, "test", "value") + if err != nil { + t.Error("Failed when inserting a new node") + } + + if _, ok := pq.dict["test"]; !ok { + t.Error("Insert node failed, expected key to exist in dictionary") + } + + // Case 2: Update existing node with same score + err = pq.InsertNode(1, "test", "newvalue") + if err != nil { + t.Error("Failed when updating a score with same value") + } + + if v, ok := pq.dict["test"]; !ok || v.value != "newvalue" { + t.Error("Update node failed, expected value to be updated") + } + + // Case 3: Insert node with existing key but different score + err = pq.InsertNode(2, "test", "newvalue") + if err != nil { + t.Error("Failed when updating a score with a new value") + } + + if v, ok := pq.dict["test"]; !ok || v.score != 2 { + t.Error("Update node failed, expected score to be updated") + } +} From ba796c8fcc9410e7b9871b4284454aba3b02d1c2 Mon Sep 17 00:00:00 2001 From: Saeid Aghapour Date: Tue, 18 Jul 2023 11:40:31 +0330 Subject: [PATCH 4/5] implements data structure, Set(#138) --- structure/zset.go | 425 ++++++++++++++++++++++++++++------------- structure/zset_test.go | 319 +++++++++++++++++++++++++++++-- 2 files changed, 595 insertions(+), 149 deletions(-) diff --git a/structure/zset.go b/structure/zset.go index 1494e46b..9f5d6052 100644 --- a/structure/zset.go +++ b/structure/zset.go @@ -4,12 +4,12 @@ import ( "errors" "fmt" "github.com/ByteStorage/FlyDB/config" - "math" - "math/rand" - "github.com/ByteStorage/FlyDB/engine" _const "github.com/ByteStorage/FlyDB/lib/const" "github.com/ByteStorage/FlyDB/lib/encoding" + "math" + "math/rand" + "time" ) const ( @@ -28,25 +28,25 @@ type ZSetStructure struct { db *engine.DB } -// ZSetNodes represents a specific data structure in the database, which is key to handling sorted sets (ZSets). +// FZSet represents a specific data structure in the database, which is key to handling sorted sets (ZSets). // This struct facilitates interactions with data stored in the sorted set, allowing for both complex and simple operations. // // It contains three struct fields: // -// - 'dict': A Go map with string keys and pointers to SkipListNodeValue values. This map aims to provide quick access to +// - 'dict': A Go map with string keys and pointers to ZSetValue values. This map aims to provide quick access to // individual values in the sorted set based on the provided key. // -// - 'size': An integer value representing the current size (number of elements) in the ZSetNodes struct. This information is efficiently +// - 'size': An integer value representing the current size (number of elements) in the FZSet struct. This information is efficiently // kept track of whenever elements are added or removed from the set, so no separate computation is needed to retrieve this information. // // - 'skipList': A pointer towards a SkipList struct. SkipLists perform well under numerous operations, such as insertion, deletion, and searching. They are // a crucial component in maintaining the sorted set in a practical manner. In this context, the SkipList is used to keep an ordered track of the elements -// in the ZSetNodes struct. -type ZSetNodes struct { +// in the FZSet struct. +type FZSet struct { // dict field is a map where the key is a string and - // the value is a pointer to SkipListNodeValue instances, + // the value is a pointer to ZSetValue instances, // codified with the tag "dict". - dict map[string]*SkipListNodeValue `codec:"dict"` + dict map[string]*ZSetValue `codec:"dict"` // size field represents the quantity of elements within // the structure, codified with the tag "size". @@ -89,7 +89,7 @@ type SkipListLevel struct { // it forms a network of nodes, where traversal of the skip list is possible both forwards and backwards. // - 'level': This is an array (slice) of pointers towards SkipListLevel structures. Each element corresponds to a level of the skip list, // embedding the 'next' node at that same level, and the span between the current node and that 'next' node. -// - 'value': This is a pointer towards a single SkipListNodeValue structure. It holds the actual payload of the node +// - 'value': This is a pointer towards a single ZSetValue structure. It holds the actual payload of the node // (namely the 'score', 'key', and 'value' properties used in the context of Redis Sorted Sets), as well as provides the basis for ordering of nodes in the skip list. type SkipListNode struct { // prev is a pointer to the previous node in the skip list. @@ -99,17 +99,17 @@ type SkipListNode struct { // Each level represents a forward pointer to the next node in the current list level. level []*SkipListLevel - // value is a pointer to the SkipListNodeValue. + // value is a pointer to the ZSetValue. // This represents the value that this node holds. - value *SkipListNodeValue + value *ZSetValue } -// SkipListNodeValue is a struct used in the SkipList data structure. In the context of Redis Sorted Set (ZSet) implementation, -// it represents a single node value in the skip list. A SkipListNodeValue has three members: +// ZSetValue is a struct used in the SkipList data structure. In the context of Redis Sorted Set (ZSet) implementation, +// it represents a single node value in the skip list. A ZSetValue has three members: // - 'score' which is an integer representing the score of the node. Nodes in a skip list are ordered by this score in ascending order. // - 'member' which is a string defining the key of the node. For nodes with equal scores, order is determined with lexicographical comparison of keys. // - 'value' which is an interface{}, meaning it can hold any data type. This represents the actual value of the node in the skip list. -type SkipListNodeValue struct { +type ZSetValue struct { // Score is typically used for sorting purposes. Nodes with higher scores will be placed higher in the skip list. score int @@ -135,12 +135,12 @@ type SkipListNodeValue struct { func randomLevel() int { // Initialize level to 1 level := 1 - + r := rand.New(rand.NewSource(time.Now().UnixNano())) // Calculate the threshold for level. It's derived from the probability constant of the skip list. thresh := int(math.Round(SKIPLIST_PROB * 0xFFF)) // While a randomly generated number is less than this threshold, increment the level. - for int(rand.Int31()&0xFFF) < thresh { + for int(r.Int31()&0xFFF) < thresh { level++ } @@ -162,14 +162,14 @@ func NewZSetStructure(options config.Options) (*ZSetStructure, error) { return &ZSetStructure{db: db}, nil } -// newZSetNodes is a function that creates a new ZSetNodes object and returns a pointer to it. +// newZSetNodes is a function that creates a new FZSet object and returns a pointer to it. // It initializes the dictionary member dict of the newly created object to an empty map. -// The map is intended to map strings to pointers of SkipListNodeValue objects. -// size member of the object is set to 0, indicating that the ZSetNodes object is currently empty. +// The map is intended to map strings to pointers of ZSetValue objects. +// size member of the object is set to 0, indicating that the FZSet object is currently empty. // The skipList member of the object is set to a new SkipList object created by calling `newSkipList()` function. -func newZSetNodes() *ZSetNodes { - return &ZSetNodes{ - dict: make(map[string]*SkipListNodeValue), +func newZSetNodes() *FZSet { + return &FZSet{ + dict: make(map[string]*ZSetValue), size: 0, skipList: newSkipList(), } @@ -205,26 +205,26 @@ func newSkipListNode(level int, score int, key string, value interface{}) *SkipL return node } -// newSkipListNodeValue is a function that constructs and returns a new SkipListNodeValue. +// newSkipListNodeValue is a function that constructs and returns a new ZSetValue. // It takes a score (int), a key (string), and a value (interface{}) as parameters. -// These parameters serve as the initial state of the SkipListNodeValue upon its creation. -func newSkipListNodeValue(score int, member string, value interface{}) *SkipListNodeValue { - // Create a new instance of a SkipListNodeValue with the provided score, key, and value. - node := &SkipListNodeValue{ +// These parameters serve as the initial state of the ZSetValue upon its creation. +func newSkipListNodeValue(score int, member string, value interface{}) *ZSetValue { + // Create a new instance of a ZSetValue with the provided score, key, and value. + node := &ZSetValue{ score: score, member: member, value: value, } - // Return the newly created SkipListNodeValue. + // Return the newly created ZSetValue. return node } // insert is a method of the SkipList type that is used to insert a new node into the skip list. It takes as arguments -// the score (int), key (string) and a value (interface{}), and returns a pointer to the SkipListNodeValue struct. The method +// the score (int), key (string) and a value (interface{}), and returns a pointer to the ZSetValue struct. The method // organizes nodes in the list based on the score in ascending order. If two nodes have the same score, they will be arranged // based on the key value. The method also assigns span values to the levels in the skip list. -func (sl *SkipList) insert(score int, key string, value interface{}) *SkipListNodeValue { +func (sl *SkipList) insert(score int, key string, value interface{}) *ZSetValue { update := make([]*SkipListNode, SKIPLIST_MAX_LEVEL) rank := make([]int, SKIPLIST_MAX_LEVEL) node := sl.head @@ -326,7 +326,7 @@ func (sl *SkipList) delete(score int, member string) { sl.deleteNode(node, update) } } -func (sl *SkipList) getRange(start int, end int, reverse bool) (nv []SkipListNodeValue) { +func (sl *SkipList) getRange(start int, end int, reverse bool) (nv []ZSetValue) { if end > sl.size { end = sl.size - 1 } @@ -369,7 +369,6 @@ func (sl *SkipList) getRange(start int, end int, reverse bool) (nv []SkipListNod // - Then, it sets the pointers back to the previous node in the data structure and updates the tail and level of the whole list. // Finally, it decreases the size of the list by one, as a node is being removed from it. // It doesn't return any value and modifies the SkipList directly. - func (sl *SkipList) deleteNode(node *SkipListNode, updates []*SkipListNode) { for i := 0; i < sl.level; i++ { if updates[i].level[i].next == node { @@ -485,29 +484,56 @@ func (sl *SkipList) getNodeByRank(rank int) *SkipListNode { // // If the key is an empty string, an error will be returned func (zs *ZSetStructure) ZAdd(key string, score int, member string, value string) error { - if len(key) == 0 { - return _const.ErrKeyIsEmpty - } + return zs.ZAdds(key, []ZSetValue{{score: score, member: member, value: value}}...) +} +// ZAdds adds a value with its given score and member to a sorted set (ZSet), associated with +// the provided key. It is a method on the ZSetStructure type. +// +// Parameters: +// +// values: ...ZSetValue multiple values of ZSetValue. +func (zs *ZSetStructure) ZAdds(key string, vals ...ZSetValue) error { + if err := checkKey(key); err != nil { + return err + } zSet, err := zs.getOrCreateZSet(key) - if err != nil { return fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) } - // if values didn't change, do nothing - if zs.valuesDidntChange(zSet, score, member, value) { - return nil - } - - if err := zs.updateZSet(zSet, key, score, member, value); err != nil { - return fmt.Errorf("failed to set ZSet to DB with key '%v': %w", key, err) + for _, val := range vals { + // if values didn't change, do nothing + if zs.valuesDidntChange(zSet, val.score, val.member, val.value.(string)) { + continue + } + if err := zs.updateZSet(zSet, key, val.score, val.member, val.value.(string)); err != nil { + return fmt.Errorf("failed to set ZSet to DB with key '%v': %w", key, err) + } } - return nil + return zs.setZSetToDB(stringToBytesWithKey(key), zSet) } + +// exists checks if a given member with a specific score exists in a ZSet. It +// also verifies if the provided key is valid. The function returns a boolean +// value indicating whether the member with the specified score exists in the +// ZSet or not. +// +// Parameters: +// +// key (string): Specifies the key of the ZSet. +// score (int): The score of the member to be checked. +// member (string): The specific member to check for in the ZSet. +// +// Returns: +// +// bool: A boolean value indicating whether a member with the specified score +// +// exists in the ZSet or not. Returns false if the ZSet does not exist or if +// the key is invalid. func (zs *ZSetStructure) exists(key string, score int, member string) bool { - if len(key) == 0 { + if err := checkKey(key); err != nil { return false } keyBytes := stringToBytesWithKey(key) @@ -517,7 +543,6 @@ func (zs *ZSetStructure) exists(key string, score int, member string) bool { if err != nil { return false } - return zSet.exists(score, member) } @@ -537,8 +562,8 @@ the ZSet in the database. If any point of this operation fails, the function will return the corresponding error. */ func (zs *ZSetStructure) ZRem(key string, member string) error { - if len(key) == 0 { - return _const.ErrKeyIsEmpty + if err := checkKey(key); err != nil { + return err } keyBytes := stringToBytesWithKey(key) @@ -553,10 +578,37 @@ func (zs *ZSetStructure) ZRem(key string, member string) error { return zs.setZSetToDB(keyBytes, zSet) } +// ZRems method removes one or more specified members from the sorted set that's stored under the provided key. +// Params: +// - key string: the identifier for storing the sorted set in the database. +// - member ...string: a variadic parameter where each argument is a member string to remove. +// +// Returns: error +// +// The function will return an error if it fails at any point, if not it will return nil indicating a successful operation. +func (zs *ZSetStructure) ZRems(key string, member ...string) error { + if err := checkKey(key); err != nil { + return err + } + keyBytes := stringToBytesWithKey(key) + + zSet, err := zs.getZSetFromDB(keyBytes) + + if err != nil { + return fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + } + for _, s := range member { + if err = zSet.RemoveNode(s); err != nil { + return err + } + } + return zs.setZSetToDB(keyBytes, zSet) +} + // ZScore method retrieves the score associated with the member in a sorted set stored at the key func (zs *ZSetStructure) ZScore(key string, member string) (int, error) { - if len(key) == 0 { - return 0, _const.ErrKeyIsEmpty + if err := checkKey(key); err != nil { + return 0, err } keyBytes := stringToBytesWithKey(key) @@ -602,8 +654,8 @@ rank, err := zs.ZRank("myKey", "memberName") fmt.Printf("The rank of '%s' in the set '%s' is %d\n", "memberName", "myKey", rank) */ func (zs *ZSetStructure) ZRank(key string, member string) (int, error) { - if len(key) == 0 { - return 0, _const.ErrKeyIsEmpty + if err := checkKey(key); err != nil { + return 0, err } keyBytes := stringToBytesWithKey(key) @@ -638,8 +690,8 @@ func (zs *ZSetStructure) ZRank(key string, member string) (int, error) { // // Note: The reverse rank is calculated as 'size - rank', and the ranks start from 1. func (zs *ZSetStructure) ZRevRank(key string, member string) (int, error) { - if len(key) == 0 { - return 0, _const.ErrKeyIsEmpty + if err := checkKey(key); err != nil { + return 0, err } keyBytes := stringToBytesWithKey(key) @@ -657,7 +709,7 @@ func (zs *ZSetStructure) ZRevRank(key string, member string) (int, error) { } // ZRange retrieves a specific range of elements from a sorted set (ZSet) denoted by a specific key. -// It returns a slice of SkipListNodeValue containing the elements within the specified range (inclusive), and a nil error when successful. +// It returns a slice of ZSetValue containing the elements within the specified range (inclusive), and a nil error when successful. // // The order of the returned elements is based on their rank in the set, not their score. // @@ -669,8 +721,8 @@ func (zs *ZSetStructure) ZRevRank(key string, member string) (int, error) { // // Returns: // -// []SkipListNodeValue: -// Slice of SkipListNodeValue containing elements within the specified range. +// []ZSetValue: +// Slice of ZSetValue containing elements within the specified range. // error: // An error if it occurs during execution, such as: // 1. The provided key string is empty. @@ -686,9 +738,9 @@ func (zs *ZSetStructure) ZRevRank(key string, member string) (int, error) { // ZRange("someKey", 0, 2) will return ["element1", "element2", "element3"] and nil error. // // This method is part of the ZSetStructure type. -func (zs *ZSetStructure) ZRange(key string, start int, end int) ([]SkipListNodeValue, error) { - if len(key) == 0 { - return nil, _const.ErrKeyIsEmpty +func (zs *ZSetStructure) ZRange(key string, start int, end int) ([]ZSetValue, error) { + if err := checkKey(key); err != nil { + return nil, err } keyBytes := stringToBytesWithKey(key) @@ -702,6 +754,60 @@ func (zs *ZSetStructure) ZRange(key string, start int, end int) ([]SkipListNodeV return r, nil } +// ZCount traverses through the elements of the ZSetStructure based on the given key. +// The count of elements between the range of min and max scores is determined. +// +// The method takes a string as the key and two integers as min and max ranges. +// The range values are inclusive: [min, max]. If min is greater than max, an error is returned. +// The function ignores scores that fall out of the specified min and max range. +// +// It returns the count of elements within the range, and an error if any occurs during the process. +// +// For example, use as follows: +// count, err := zs.ZCount("exampleKey", 10, 50) +// This will count the number of elements that have the scores between 10 and 50 in the ZSetStructure associated with "exampleKey". +// +// Returns: +// 1. int: The total count of elements based on the score range. +// 2. error: Errors that occurred during execution, if any. +func (zs *ZSetStructure) ZCount(key string, min int, max int) (count int, err error) { + if err = checkKey(key); err != nil { + return 0, err + } + keyBytes := stringToBytesWithKey(key) + zSet, err := zs.getZSetFromDB(keyBytes) + if err != nil { + return 0, err + } + if min > max { + return 0, ErrInvalidArgs + } + min, max, err = zs.adjustMinMax(zSet, min, max) + if err != nil { + return 0, err + } + x := zSet.skipList.head + // Node traversal loop. We keep moving to the next node at current level + // as long as the score of the next node's value is less than 'min'. + for i := zSet.skipList.level - 1; i >= 0; i-- { + for x.level[i].next != nil && x.level[i].next.value.score < min { + x = x.level[i].next + } + } + + x = x.level[0].next + // Score range check loop. We traverse nodes and increment 'count' + // as long as node value's score is in the range ['min', 'max'] + for x != nil { + if x.value.score > max { + break + } + count++ + x = x.level[0].next + } + return count, nil +} + // ZRevRange retrieves a range of elements from a sorted set (ZSet) in descending order. // Inputs: // - key: Name of the ZSet @@ -709,12 +815,12 @@ func (zs *ZSetStructure) ZRange(key string, start int, end int) ([]SkipListNodeV // - endRank: Final rank of the desired range // // Output: -// - An array of SkipListNodeValue, representing elements from the range [startRank, endRank] in descending order +// - An array of ZSetValue, representing elements from the range [startRank, endRank] in descending order // - Error if an issue occurs, such as when the key is empty or ZSet retrieval fails // error -func (zs *ZSetStructure) ZRevRange(key string, startRank int, endRank int) ([]SkipListNodeValue, error) { - if len(key) == 0 { - return nil, _const.ErrKeyIsEmpty +func (zs *ZSetStructure) ZRevRange(key string, startRank int, endRank int) ([]ZSetValue, error) { + if err := checkKey(key); err != nil { + return nil, err } keyBytes := stringToBytesWithKey(key) @@ -731,8 +837,8 @@ func (zs *ZSetStructure) ZRevRange(key string, startRank int, endRank int) ([]Sk // The ZCard function returns the size of the dictionary of the sorted set stored at key in the database. // It takes a string key as an argument. func (zs *ZSetStructure) ZCard(key string) (int, error) { - if len(key) == 0 { - return 0, _const.ErrKeyIsEmpty + if err := checkKey(key); err != nil { + return 0, err } keyBytes := stringToBytesWithKey(key) @@ -760,10 +866,9 @@ func (zs *ZSetStructure) ZCard(key string) (int, error) { // if there's an issue with node insertion, // if unable to set ZSet to DB post increment operation func (zs *ZSetStructure) ZIncrBy(key string, member string, incBy int) error { - if len(key) == 0 { - return _const.ErrKeyIsEmpty + if err := checkKey(key); err != nil { + return err } - keyBytes := stringToBytesWithKey(key) zSet, err := zs.getZSetFromDB(keyBytes) @@ -778,14 +883,14 @@ func (zs *ZSetStructure) ZIncrBy(key string, member string, incBy int) error { if err = zs.setZSetToDB(keyBytes, zSet); err != nil { return err } - return nil + return zs.setZSetToDB(keyBytes, zSet) } return _const.ErrKeyNotFound } // getOrCreateZSet attempts to retrieve a sorted set by a key, or creates a new one if it doesn't exist. -func (zs *ZSetStructure) getOrCreateZSet(key string) (*ZSetNodes, error) { +func (zs *ZSetStructure) getOrCreateZSet(key string) (*FZSet, error) { keyBytes := stringToBytesWithKey(key) zSet, err := zs.getZSetFromDB(keyBytes) // if key is not in the DB, create it. @@ -797,7 +902,7 @@ func (zs *ZSetStructure) getOrCreateZSet(key string) (*ZSetNodes, error) { } // valuesDidntChange checks if the data of a specific member in a sorted set remained the same. -func (zs *ZSetStructure) valuesDidntChange(zSet *ZSetNodes, score int, member string, value string) bool { +func (zs *ZSetStructure) valuesDidntChange(zSet *FZSet, score int, member string, value string) bool { if v, ok := zSet.dict[member]; ok { return v.score == score && v.member == member && v.value == value } @@ -806,15 +911,11 @@ func (zs *ZSetStructure) valuesDidntChange(zSet *ZSetNodes, score int, member st } // updateZSet updates or inserts a member in a sorted set and saves the change in storage. -func (zs *ZSetStructure) updateZSet(zSet *ZSetNodes, key string, score int, member string, value string) error { - if err := zSet.InsertNode(score, member, value); err != nil { - return err - } - - return zs.setZSetToDB(stringToBytesWithKey(key), zSet) +func (zs *ZSetStructure) updateZSet(zSet *FZSet, key string, score int, member string, value string) error { + return zSet.InsertNode(score, member, value) } -// InsertNode is a method on the ZSetNodes structure. It inserts a new node +// InsertNode is a method on the FZSet structure. It inserts a new node // or updates an existing node in the skip list and the dictionary. // It takes three parameters: score (an integer), key (a string), // and value (of any interface type). @@ -828,42 +929,77 @@ func (zs *ZSetStructure) updateZSet(zSet *ZSetNodes, key string, score int, memb // If the key doesn't exist in the dictionary, it adds the new key, value and score // to the dictionary, increments the size of the dictionary by 1, and also adds // the node to the skip list. -func (pq *ZSetNodes) InsertNode(score int, member string, value interface{}) error { +func (fzs *FZSet) InsertNode(score int, member string, value interface{}) error { // Instantiate dictionary if it's not already - if pq.dict == nil { - pq.dict = make(map[string]*SkipListNodeValue) + if fzs.dict == nil { + fzs.dict = make(map[string]*ZSetValue) } - if pq.skipList == nil { - pq.skipList = newSkipList() + if fzs.skipList == nil { + fzs.skipList = newSkipList() } // Check if key exists in dictionary - if v, ok := pq.dict[member]; ok { + if v, ok := fzs.dict[member]; ok { if v.score != score { // Update value and score as the score remains the same - pq.skipList.delete(score, member) - pq.dict[member] = pq.skipList.insert(score, member, value) + fzs.skipList.delete(score, member) + fzs.dict[member] = fzs.skipList.insert(score, member, value) } else { // Ranking isn't altered, only update value v.value = value } } else { // Key doesn't exist, create new key - pq.dict[member] = pq.skipList.insert(score, member, value) - pq.size++ // Increase size count by 1 + fzs.dict[member] = fzs.skipList.insert(score, member, value) + fzs.size++ // Increase size count by 1 // Node is also added to the skip list } // Returns nil as no specific error condition is checked in this function return nil } +func (zs *ZSetStructure) adjustMinMax(zSet *FZSet, min int, max int) (adjustedMin int, adjustedMax int, err error) { + if min > max { + return min, max, ErrInvalidArgs + } + minScore, maxScore := zSet.getMinMaxScore() + return zSet.max(min, minScore), zSet.min(max, maxScore), nil +} +func (fzs *FZSet) getMinMaxScore() (minScore int, maxScore int) { + if fzs == nil || fzs.skipList == nil || fzs.skipList.head == nil || len(fzs.skipList.head.level) < 1 || fzs.skipList.head.level[0].next == nil || fzs.skipList.tail == nil { + return 0, 0 + } + + if fzs.skipList.head.level[0].next.value == nil || fzs.skipList.tail.value == nil { + return 0, 0 + } + return fzs.skipList.head.level[0].next.value.score, + fzs.skipList.tail.value.score +} +func (fzs *FZSet) min(a, b int) int { + if a < b { + return a + } + return b +} + +func (fzs *FZSet) max(a, b int) int { + if a > b { + return a + } + return b +} +func getMinMaxScore(zSet *FZSet) (minScore int, maxScore int) { + return zSet.skipList.head.level[0].next.value.score, + zSet.skipList.tail.value.score +} -// RemoveNode is a method for ZSetNodes structure. +// RemoveNode is a method for FZSet structure. // This method aims to delete a node from both // the dictionary (dict) and the skip list (skipList). // // The method receives one parameter: // - member: a string that represents the key of the node -// to be removed from the ZSetNodes structure. +// to be removed from the FZSet structure. // // The method follows these steps: // 1. Check if a node with key 'member' exists in the dictionary. @@ -875,52 +1011,52 @@ func (pq *ZSetNodes) InsertNode(score int, member string, value interface{}) err // the success of the operation. // // The RemoveNode's primary purpose is to provide a way to securely and -// efficiently remove a node from the ZSetNodes structure. -func (pq *ZSetNodes) RemoveNode(member string) error { +// efficiently remove a node from the FZSet structure. +func (fzs *FZSet) RemoveNode(member string) error { // Check for existence of key in dictionary - v, ok := pq.dict[member] - if !ok || pq.dict == nil { + v, ok := fzs.dict[member] + if !ok || fzs.dict == nil { return _const.ErrKeyNotFound } // Delete Node from the skip list and dictionary - pq.skipList.delete(v.score, member) - delete(pq.dict, member) - pq.size-- + fzs.skipList.delete(v.score, member) + delete(fzs.dict, member) + fzs.size-- return nil } -func (pq *ZSetNodes) exists(score int, member string) bool { - v, ok := pq.dict[member] +func (fzs *FZSet) exists(score int, member string) bool { + v, ok := fzs.dict[member] return ok && v.score == score } -// Bytes encodes the ZSetNodes instance into bytes using MessagePack +// Bytes encodes the FZSet instance into bytes using MessagePack // binary serialization format. The encoded bytes can be used for // storage or transmission. If the encoding operation fails, an // error is returned. -func (pq *ZSetNodes) Bytes() ([]byte, error) { +func (fzs *FZSet) Bytes() ([]byte, error) { var msgPack = encoding.NewMessagePackEncoder() - if encodingError := msgPack.Encode(pq); encodingError != nil { + if encodingError := msgPack.Encode(fzs); encodingError != nil { return nil, encodingError } return msgPack.Bytes(), nil } -// FromBytes decodes the input byte slice into the ZSetNodes object using MessagePack. +// FromBytes decodes the input byte slice into the FZSet object using MessagePack. // Returns an error if decoding fails, otherwise nil. -func (pq *ZSetNodes) FromBytes(b []byte) error { - return encoding.NewMessagePackDecoder(b).Decode(pq) +func (fzs *FZSet) FromBytes(b []byte) error { + return encoding.NewMessagePackDecoder(b).Decode(fzs) } -// getZSetFromDB fetches and deserializes ZSetNodes from the database. +// getZSetFromDB fetches and deserializes FZSet from the database. // -// Returns a pointer to the ZSetNodes and error, if any. +// Returns a pointer to the FZSet and error, if any. // If the key doesn't exist, both the pointer and the error will be nil. // In case of deserialization errors, returns nil and the error. -func (zs *ZSetStructure) getZSetFromDB(key []byte) (*ZSetNodes, error) { +func (zs *ZSetStructure) getZSetFromDB(key []byte) (*FZSet, error) { dbData, err := zs.db.Get(key) // If key is not found, return nil for both; otherwise return the error. @@ -930,25 +1066,44 @@ func (zs *ZSetStructure) getZSetFromDB(key []byte) (*ZSetNodes, error) { } dec := encoding.NewMessagePackDecoder(dbData) // Deserialize the data. - var zSetValue ZSetNodes + var zSetValue FZSet if err = dec.Decode(&zSetValue); err != nil { return nil, err } - // return a pointer to the deserialized ZSetNodes, nil for the error + // return a pointer to the deserialized FZSet, nil for the error return &zSetValue, nil } -// setZSetToDB writes a ZSetNodes object to the database. +// checkKey function that accepts a string parameter key +// and returns error if key is empty. +// +// # It returns nil otherwise +// +// Parameters: +// +// key : A string that is checked if empty +// +// Returns: +// +// error : _const.ErrKeyIsEmpty if key is empty, nil otherwise +func checkKey(key string) error { + if len(key) == 0 { + return _const.ErrKeyIsEmpty + } + return nil +} + +// setZSetToDB writes a FZSet object to the database. // // parameters: // key: This is a byte slice that is used as a key in the database. -// zSetValue: This is a pointer to a ZSetNodes object that needs to be stored in the database. +// zSetValue: This is a pointer to a FZSet object that needs to be stored in the database. // -// The function serializes the ZSetNodes object into MessagePack format. If an error occurs +// The function serializes the FZSet object into MessagePack format. If an error occurs // either during serialization or when writing to the database, that specific error is returned. // If the process is successful, it returns nil. -func (zs *ZSetStructure) setZSetToDB(key []byte, zSetValue *ZSetNodes) error { +func (zs *ZSetStructure) setZSetToDB(key []byte, zSetValue *FZSet) error { val := encoding.NewMessagePackEncoder() err := val.Encode(zSetValue) if err != nil { @@ -957,7 +1112,7 @@ func (zs *ZSetStructure) setZSetToDB(key []byte, zSetValue *ZSetNodes) error { return zs.db.Put(key, val.Bytes()) } -// UnmarshalBinary de-serializes the given byte slice into ZSetNodes instance +// UnmarshalBinary de-serializes the given byte slice into FZSet instance // it uses MessagePack format for de-serialization // Returns an error if the decoding of size or insertion of node fails. // @@ -966,7 +1121,7 @@ func (zs *ZSetStructure) setZSetToDB(key []byte, zSetValue *ZSetNodes) error { // // Returns: // An error that will be nil if the function succeeds. -func (p *ZSetNodes) UnmarshalBinary(data []byte) (err error) { +func (fzs *FZSet) UnmarshalBinary(data []byte) (err error) { // NewMessagePackDecoder creates a new MessagePack decoder with the provided data dec := encoding.NewMessagePackDecoder(data) @@ -978,33 +1133,33 @@ func (p *ZSetNodes) UnmarshalBinary(data []byte) (err error) { // Iterate through each node in the data structure for i := 0; i < size; i++ { - // Create an empty instance of SkipListNodeValue for each node - slValue := SkipListNodeValue{} + // Create an empty instance of ZSetValue for each node + slValue := ZSetValue{} - // Decode each node onto the empty SkipListNodeValue instance + // Decode each node onto the empty ZSetValue instance if err = dec.Decode(&slValue); err != nil { return err // error handling if something goes wrong with decoding } - // Insert the decoded node into the ZSetNodes instance - if err = p.InsertNode(slValue.score, slValue.member, slValue.value); err != nil { + // Insert the decoded node into the FZSet instance + if err = fzs.InsertNode(slValue.score, slValue.member, slValue.value); err != nil { return err } } return // if all nodes are correctly decoded and inserted, return with nil error } -// MarshalBinary serializes the ZSetNodes instance into a byte slice. +// MarshalBinary serializes the FZSet instance into a byte slice. // It uses MessagePack format for serialization // Returns the serialized byte slice and an error if the encoding fails. -func (d *ZSetNodes) MarshalBinary() (_ []byte, err error) { +func (fzs *FZSet) MarshalBinary() (_ []byte, err error) { // Initializing the MessagePackEncoder enc := encoding.NewMessagePackEncoder() // Encoding the size attribute of d (i.e., d.size). The operation could fail, thus we check for an error. // An error, if occurred, will be returned immediately, hence the flow of execution stops here. - err = enc.Encode(d.size) + err = enc.Encode(fzs.size) if err != nil { return nil, err } @@ -1016,7 +1171,7 @@ func (d *ZSetNodes) MarshalBinary() (_ []byte, err error) { // we do that to get the elements in reverse order from biggest to the smallest for the best // insertion efficiency as it makes the insertion O(1), because each new element to be inserted is // the smallest yet. - x := d.skipList.tail + x := fzs.skipList.tail // as long as there are elements in the SkipList continue for x != nil { // Encoding the value of the current node in the skip list @@ -1031,15 +1186,15 @@ func (d *ZSetNodes) MarshalBinary() (_ []byte, err error) { } // After the traversal of the skip list, the encoder should now hold the serialized representation of the - // ZSetNodes. Now, we return the bytes from the encoder along with any error that might have occurred + // FZSet. Now, we return the bytes from the encoder along with any error that might have occurred // during the encoding (should be nil if everything went fine). return enc.Bytes(), err } -// UnmarshalBinary de-serializes the given byte slice into SkipListNodeValue instance +// UnmarshalBinary de-serializes the given byte slice into ZSetValue instance // It uses the MessagePack format for de-serialization // Returns an error if the decoding of Key, Score, or Value fails. -func (p *SkipListNodeValue) UnmarshalBinary(data []byte) (err error) { +func (p *ZSetValue) UnmarshalBinary(data []byte) (err error) { dec := encoding.NewMessagePackDecoder(data) if err = dec.Decode(&p.member); err != nil { return @@ -1054,14 +1209,14 @@ func (p *SkipListNodeValue) UnmarshalBinary(data []byte) (err error) { } // MarshalBinary uses MessagePack as the encoding format to serialize -// the SkipListNodeValue object into a byte array. -func (d *SkipListNodeValue) MarshalBinary() (_ []byte, err error) { +// the ZSetValue object into a byte array. +func (d *ZSetValue) MarshalBinary() (_ []byte, err error) { // The NewMessagePackEncoder function is called to create a new // MessagePack encoder. enc := encoding.NewMessagePackEncoder() - // Then, we try to encode the 'key' field of the SkipListNodeValue + // Then, we try to encode the 'key' field of the ZSetValue // If an error occurs, it is returned immediately along with the // currently encoded byte slice. if err = enc.Encode(d.member); err != nil { @@ -1079,7 +1234,7 @@ func (d *SkipListNodeValue) MarshalBinary() (_ []byte, err error) { } // If everything goes well and we're done encoding, we return the - // final byte slice which represents the encoded SkipListNodeValue + // final byte slice which represents the encoded ZSetValue // and a nil error. return enc.Bytes(), err } diff --git a/structure/zset_test.go b/structure/zset_test.go index 87d0d8ae..5bb2a629 100644 --- a/structure/zset_test.go +++ b/structure/zset_test.go @@ -5,7 +5,6 @@ import ( _const "github.com/ByteStorage/FlyDB/lib/const" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "math/rand" "os" "reflect" "testing" @@ -23,7 +22,7 @@ func TestSortedSet(t *testing.T) { type test struct { name string input map[string]int - want *ZSetNodes + want *FZSet expectError bool } @@ -45,7 +44,7 @@ func TestSortedSet(t *testing.T) { { name: "empty", input: map[string]int{}, - want: &ZSetNodes{}, + want: &FZSet{}, expectError: false, }, { @@ -100,6 +99,90 @@ func TestZRem(t *testing.T) { }) } } +func TestZRems(t *testing.T) { + mockZSetStructure, _ := initZSetDB() + + // 1. Test for Key is Empty + err := mockZSetStructure.ZRems("", "member") + require.Error(t, err) + require.Equal(t, _const.ErrKeyIsEmpty, err) + type testCase struct { + key string + input []ZSetValue + rems []string + want []ZSetValue + err error + } + + testCases := []testCase{ + {"key", + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + []string{ + "mem0", + "mem1", + "mem6", + }, []ZSetValue{ + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + }, nil}, + {"", + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + }, + []string{ + "mem0", + "mem1", + "mem2", + "mem3", + "mem4", + "mem5", + "mem6", + }, + []ZSetValue{}, + _const.ErrKeyIsEmpty}, + { + "Key1", + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + }, + []string{ + "mem0", + "mem1", + "mem2", + "mem3", + "mem4", + "mem5", + "mem6", + }, []ZSetValue{}, _const.ErrKeyNotFound}, + } + + for _, tc := range testCases { + t.Run(tc.key, func(t *testing.T) { + _ = mockZSetStructure.ZAdds(tc.key, tc.input...) + + //remove all the elements + err = mockZSetStructure.ZRems(tc.key, tc.rems...) + assert.Equal(t, tc.err, err) + //validate + for _, value := range tc.want { + te := mockZSetStructure.exists(tc.key, value.score, value.member) + assert.True(t, te) + } + }) + } +} func TestZAdd(t *testing.T) { zs, _ := initZSetDB() type testCase struct { @@ -107,7 +190,7 @@ func TestZAdd(t *testing.T) { score int member string value string - want SkipListNodeValue + want ZSetValue err error } @@ -117,7 +200,7 @@ func TestZAdd(t *testing.T) { 10, "member", "value", - SkipListNodeValue{member: "member"}, + ZSetValue{member: "member"}, nil, }, { @@ -125,7 +208,7 @@ func TestZAdd(t *testing.T) { 10, "member", "value", - SkipListNodeValue{member: ""}, + ZSetValue{member: ""}, _const.ErrKeyIsEmpty, }, } @@ -147,6 +230,65 @@ func TestZAdd(t *testing.T) { }) } } +func TestZAdds(t *testing.T) { + zs, _ := initZSetDB() + + // 1. Test for Key is Empty + err := zs.ZAdds("", []ZSetValue{}...) + require.Error(t, err) + require.Equal(t, _const.ErrKeyIsEmpty, err) + type testCase struct { + key string + input []ZSetValue + want []ZSetValue + err error + } + + testCases := []testCase{ + {"key", + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + nil}, + {"", + []ZSetValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + }, + []ZSetValue{}, + _const.ErrKeyIsEmpty, + }, + } + + for _, tc := range testCases { + t.Run(tc.key, func(t *testing.T) { + err = zs.ZAdds(tc.key, tc.input...) + assert.Equal(t, tc.err, err) + //validate + for _, value := range tc.want { + te := zs.exists(tc.key, value.score, value.member) + assert.True(t, te) + } + }) + } +} func TestZIncrBy(t *testing.T) { zs, _ := initZSetDB() err := zs.ZIncrBy("", "non-existingMember", 5) @@ -290,10 +432,10 @@ func TestZRevRange(t *testing.T) { key string start int end int - want []SkipListNodeValue + want []ZSetValue wantErr error }{ - {"myKey", 0, 3, []SkipListNodeValue{ + {"myKey", 0, 3, []ZSetValue{ {6, "member6", n}, {5, "member5", n}, {4, "member4", n}, @@ -341,10 +483,10 @@ func TestZRange(t *testing.T) { key string start int end int - want []SkipListNodeValue + want []ZSetValue wantErr error }{ - {"myKey", 0, 3, []SkipListNodeValue{ + {"myKey", 0, 3, []ZSetValue{ {1, "member1", n}, {2, "member2", n}, {3, "member3", n}, @@ -547,15 +689,13 @@ type testZSetNodeValue struct { value interface{} } -func populateSkipListFromSlice(nodes *ZSetNodes, zSetNodeValues []testZSetNodeValue) { +func populateSkipListFromSlice(nodes *FZSet, zSetNodeValues []testZSetNodeValue) { // Iterate over the zsetNodes array for _, zSetNode := range zSetNodeValues { _ = nodes.InsertNode(zSetNode.score, zSetNode.member, zSetNode.value) } } func TestRandomLevel(t *testing.T) { - rand.Seed(1) - for i := 0; i < 1000; i++ { level := randomLevel() if level < 1 || level > SKIPLIST_MAX_LEVEL { @@ -564,7 +704,7 @@ func TestRandomLevel(t *testing.T) { } } func TestZSetNodes_InsertNode(t *testing.T) { - pq := &ZSetNodes{} + pq := &FZSet{} // Case 1: Insert new node err := pq.InsertNode(1, "test", "value") @@ -596,3 +736,154 @@ func TestZSetNodes_InsertNode(t *testing.T) { t.Error("Update node failed, expected score to be updated") } } +func TestZCount(t *testing.T) { + zs, _ := initZSetDB() + + tests := []struct { + key string + input []testZSetNodeValue + min int + max int + want int + err error + }{ + { + "test1", + []testZSetNodeValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + 1, 5, 5, nil, + }, + { + "test2", + []testZSetNodeValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + 0, 5, 6, nil, + }, + { + "test3", + []testZSetNodeValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + 1, 3, 3, nil, + }, + { + "test4", + []testZSetNodeValue{ + {score: 0, member: "mem0", value: ""}, + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}, + {score: 5, member: "mem5", value: ""}, + {score: 6, member: "mem6", value: ""}, + }, + 2, 2, 1, nil, + }, + { + "test5", + []testZSetNodeValue{ + {score: 3, member: "mem3", value: ""}, + }, + 10, 20, 0, nil, + }, + { + "test6", + []testZSetNodeValue{ + {score: 3, member: "mem3", value: ""}, + }, + 10, 5, 0, ErrInvalidArgs, + }, + { + "", + []testZSetNodeValue{ + {score: 3, member: "mem3", value: ""}, + }, + 10, 5, 0, _const.ErrKeyIsEmpty, + }, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + for _, value := range tt.input { + _ = zs.ZAdd(tt.key, value.score, value.member, value.value.(string)) + } + got, err := zs.ZCount(tt.key, tt.min, tt.max) + if got != tt.want { + t.Errorf("TestZCount(%v, %v, %v) = %v, want: %v", tt.key, tt.min, tt.max, got, tt.want) + } + if err != nil && err.Error() != tt.err.Error() { + t.Errorf("TestZCount(%v, %v, %v) returned unexpected error: got %v, want: %v", tt.key, tt.min, tt.max, err, tt.err) + } + }) + } +} + +func TestFZSetMinMax(t *testing.T) { + fzs := &FZSet{ + skipList: newSkipList(), + } + _ = fzs.InsertNode(1, "mem1", "") + _ = fzs.InsertNode(100, "mem2", "") + + minScore, maxScore := fzs.getMinMaxScore() + + if minScore != 1 || maxScore != 100 { + t.Errorf("getMinMaxScore() = %d, %d, want: 1, 100", minScore, maxScore) + } + + if min := fzs.min(5, 10); min != 5 { + t.Errorf("min(5, 10) = %d, want: 5", min) + } + + if max := fzs.max(5, 10); max != 10 { + t.Errorf("max(5, 10) = %d, want: 10", max) + } +} + +func Test_exists(t *testing.T) { + zs, _ := initZSetDB() + + tt := []struct { + key string + score int + member string + want bool + }{ + { + key: "", + score: 1, + member: "", + want: false, + }, + } + + for _, tc := range tt { + t.Run("", func(t *testing.T) { + got := zs.exists(tc.key, tc.score, tc.member) + + if got != tc.want { + t.Errorf("exists() = %v, want %v", got, tc.want) + } + }) + } +} From 09b3aea79067585a6547395a026528f6f0069901 Mon Sep 17 00:00:00 2001 From: Saeid Aghapour Date: Tue, 18 Jul 2023 19:11:20 +0330 Subject: [PATCH 5/5] additional unittests(#138) --- structure/zset.go | 15 ++-- structure/zset_test.go | 188 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 174 insertions(+), 29 deletions(-) diff --git a/structure/zset.go b/structure/zset.go index 9f5d6052..04b03469 100644 --- a/structure/zset.go +++ b/structure/zset.go @@ -504,10 +504,10 @@ func (zs *ZSetStructure) ZAdds(key string, vals ...ZSetValue) error { for _, val := range vals { // if values didn't change, do nothing - if zs.valuesDidntChange(zSet, val.score, val.member, val.value.(string)) { + if zs.valuesDidntChange(zSet, val.score, val.member, val.value) { continue } - if err := zs.updateZSet(zSet, key, val.score, val.member, val.value.(string)); err != nil { + if err := zs.updateZSet(zSet, key, val.score, val.member, val.value); err != nil { return fmt.Errorf("failed to set ZSet to DB with key '%v': %w", key, err) } } @@ -568,9 +568,8 @@ func (zs *ZSetStructure) ZRem(key string, member string) error { keyBytes := stringToBytesWithKey(key) zSet, err := zs.getZSetFromDB(keyBytes) - if err != nil { - return fmt.Errorf("failed to get or create ZSet from DB with key '%v': %w", key, err) + return err } if err = zSet.RemoveNode(member); err != nil { return err @@ -902,7 +901,7 @@ func (zs *ZSetStructure) getOrCreateZSet(key string) (*FZSet, error) { } // valuesDidntChange checks if the data of a specific member in a sorted set remained the same. -func (zs *ZSetStructure) valuesDidntChange(zSet *FZSet, score int, member string, value string) bool { +func (zs *ZSetStructure) valuesDidntChange(zSet *FZSet, score int, member string, value interface{}) bool { if v, ok := zSet.dict[member]; ok { return v.score == score && v.member == member && v.value == value } @@ -911,7 +910,7 @@ func (zs *ZSetStructure) valuesDidntChange(zSet *FZSet, score int, member string } // updateZSet updates or inserts a member in a sorted set and saves the change in storage. -func (zs *ZSetStructure) updateZSet(zSet *FZSet, key string, score int, member string, value string) error { +func (zs *ZSetStructure) updateZSet(zSet *FZSet, key string, score int, member string, value interface{}) error { return zSet.InsertNode(score, member, value) } @@ -988,10 +987,6 @@ func (fzs *FZSet) max(a, b int) int { } return b } -func getMinMaxScore(zSet *FZSet) (minScore int, maxScore int) { - return zSet.skipList.head.level[0].next.value.score, - zSet.skipList.tail.value.score -} // RemoveNode is a method for FZSet structure. // This method aims to delete a node from both diff --git a/structure/zset_test.go b/structure/zset_test.go index 5bb2a629..8cc0a83f 100644 --- a/structure/zset_test.go +++ b/structure/zset_test.go @@ -14,8 +14,8 @@ func initZSetDB() (*ZSetStructure, *config.Options) { opts := config.DefaultOptions dir, _ := os.MkdirTemp("", "TestZSetStructure") opts.DirPath = dir - hash, _ := NewZSetStructure(opts) - return hash, &opts + zs, _ := NewZSetStructure(opts) + return zs, &opts } func TestSortedSet(t *testing.T) { @@ -67,33 +67,76 @@ func TestSortedSet_Bytes(t *testing.T) { } func TestZRem(t *testing.T) { - mockZSetStructure, _ := initZSetDB() - // 1. Test for Key is Empty - err := mockZSetStructure.ZRem("", "member") - require.Error(t, err) - require.Equal(t, _const.ErrKeyIsEmpty, err) type testCase struct { - key string - score int - member string - value string - err error + name string + key string + setup func(z *ZSetStructure) + members []string + want []string + dontWant []string + err error } testCases := []testCase{ - {"key", 10, "member", "value", nil}, - {"", 10, "member", "value", _const.ErrKeyIsEmpty}, + { + name: "key empty", + setup: func(z *ZSetStructure) { + _ = z.ZAdds("key1", []ZSetValue{{}}...) + }, + members: []string{""}, + err: _const.ErrKeyIsEmpty, + }, + { + name: "key not found", + setup: func(z *ZSetStructure) { + _ = z.ZAdds("key1", []ZSetValue{{}}...) + }, + key: "notfound", + members: []string{""}, + err: _const.ErrKeyNotFound, + }, + { + name: "member not found", + setup: func(z *ZSetStructure) { + _ = z.ZAdds("key1", []ZSetValue{{}}...) + }, + key: "key1", + members: []string{"notfound"}, + err: _const.ErrKeyNotFound, + }, + { + name: "member empty", + setup: func(z *ZSetStructure) { + _ = z.ZAdds("key1", []ZSetValue{{score: 1, member: "mem1", value: ""}}...) + }, + key: "key1", + members: []string{""}, + err: _const.ErrKeyNotFound, + }, + { + name: "remove half members", + setup: func(z *ZSetStructure) { + _ = z.ZAdds("key1", []ZSetValue{ + {score: 1, member: "mem1", value: ""}, + {score: 2, member: "mem2", value: ""}, + {score: 3, member: "mem3", value: ""}, + {score: 4, member: "mem4", value: ""}}...) + }, + }, } for _, tc := range testCases { - t.Run(tc.key, func(t *testing.T) { - err := mockZSetStructure.ZAdd(tc.key, tc.score, tc.member, tc.value) - // check to see if element added - assert.Equal(t, tc.err, err) + t.Run(tc.name, func(t *testing.T) { + mockZSetStructure, _ := initZSetDB() + tc.setup(mockZSetStructure) + + for _, m := range tc.members { + err = mockZSetStructure.ZRem(tc.key, m) + assert.EqualError(t, err, tc.err.Error()) + } if tc.err == nil { // check if member added - assert.True(t, mockZSetStructure.exists(tc.key, tc.score, tc.member)) } }) @@ -858,8 +901,76 @@ func TestFZSetMinMax(t *testing.T) { if max := fzs.max(5, 10); max != 10 { t.Errorf("max(5, 10) = %d, want: 10", max) } + + // if case part of skip list is missing, we should return 0,0 + fzs.skipList.tail.value = nil + minScore, maxScore = fzs.getMinMaxScore() + assert.Equal(t, minScore, 0) + assert.Equal(t, maxScore, 0) + // again, an error case + fzs.skipList = nil + minScore, maxScore = fzs.getMinMaxScore() + assert.Equal(t, minScore, 0) + assert.Equal(t, maxScore, 0) + +} +func TestZSetStructure_adjustMinMax(t *testing.T) { + zss, _ := NewZSetStructure(config.DefaultOptions) + fz := newZSetNodes() + + _, _, err := zss.adjustMinMax(fz, 100, 0) + assert.Equal(t, ErrInvalidArgs, err) + // + _ = fz.InsertNode(30, "mem1", "") + _ = fz.InsertNode(200, "mem1", "") + minScore, maxScore, err := zss.adjustMinMax(fz, 10, 50) + assert.NoError(t, err) + // as the min now is 30, our provided min of 10 will be turned into 30 + // as our param of max is 50 and maximum score is 200, it won't change + assert.Equal(t, 30, minScore) + assert.Equal(t, 50, maxScore) } +func TestZset_getNodeByRank(t *testing.T) { + sl := newSkipList() + sl.insert(1, "mem1", "") + sl.insert(2, "mem2", "") + sl.insert(3, "mem3", "") + tests := []struct { + name string + rank int + want *ZSetValue // Expected Output, use your actual SkipListNode instance or null here + }{ + { + name: "Case 1: Get Node by Rank 1", + rank: 1, + want: &ZSetValue{score: 1, member: "mem1", value: ""}, + }, + { + name: "Case 2: Get Node by Rank 2", + rank: 2, + want: &ZSetValue{score: 2, member: "mem2", value: ""}, + }, + { + name: "Case 3: Get Node by Non-existed Rank", + rank: 9999, + want: nil, // should return nil if rank doesn't exist + }, + } + + // Iterate over test cases + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sl.getNodeByRank(tt.rank) + if tt.want == nil { + assert.Nil(t, got) + + } else { + assert.Equal(t, tt.want, got.value) + } + }) + } +} func Test_exists(t *testing.T) { zs, _ := initZSetDB() @@ -875,6 +986,12 @@ func Test_exists(t *testing.T) { member: "", want: false, }, + { + key: "key1", + score: 1, + member: "", + want: false, + }, } for _, tc := range tt { @@ -887,3 +1004,36 @@ func Test_exists(t *testing.T) { }) } } +func TestNewZSetStructure(t *testing.T) { + tt := []struct { + name string + setup func() (*ZSetStructure, error) + wantErr error + }{ + { + name: "init no error", + setup: func() (*ZSetStructure, error) { + opts := config.DefaultOptions + dir, _ := os.MkdirTemp("", "TestZSetStructure") + opts.DirPath = dir + return NewZSetStructure(opts) + }, + wantErr: nil, + }, + { + name: "init with error wrong path", + setup: func() (*ZSetStructure, error) { + opts := config.DefaultOptions + opts.DirPath = "" + return NewZSetStructure(opts) + }, + wantErr: _const.ErrOptionDirPathIsEmpty, + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.setup() + assert.Equal(t, tc.wantErr, err) + }) + } +}