From c6c1e5ec7690b5e5d7b47f6ab913bae6f78df03b Mon Sep 17 00:00:00 2001 From: Ibrahim Jarif Date: Fri, 13 Mar 2020 20:42:22 +0530 Subject: [PATCH] Add support for watching nil prefix in subscribe API (#1246) This PR adds support for watching empty prefixes (all keys) in subscribe API. To subscribe to all changes in badger, user can run ```go db.Subscribe(ctx, handler, nil) ``` or ```go db.Subscribe(ctx, handler, []byte{}) ``` --- db.go | 6 ++---- errors.go | 3 --- trie/trie.go | 7 +++++++ trie/trie_test.go | 27 +++++++++++++++++++++------ 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/db.go b/db.go index 1f0089d41..f7d8473e5 100644 --- a/db.go +++ b/db.go @@ -1580,9 +1580,7 @@ func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList) error, prefixes if cb == nil { return ErrNilCallback } - if len(prefixes) == 0 { - return ErrNoPrefixes - } + c := y.NewCloser(1) recvCh, id := db.pub.newSubscriber(c, prefixes...) slurp := func(batch *pb.KVList) error { @@ -1616,7 +1614,7 @@ func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList) error, prefixes err := slurp(batch) if err != nil { c.Done() - // Delete the subsriber if there is an error by the callback. + // Delete the subscriber if there is an error by the callback. db.pub.deleteSubscriber(id) return err } diff --git a/errors.go b/errors.go index bb9891e78..61dec9ba9 100644 --- a/errors.go +++ b/errors.go @@ -110,9 +110,6 @@ var ( // ErrNilCallback is returned when subscriber's callback is nil. ErrNilCallback = errors.New("Callback cannot be nil") - // ErrNoPrefixes is returned when subscriber doesn't provide any prefix. - ErrNoPrefixes = errors.New("At least one key prefix is required") - // ErrEncryptionKeyMismatch is returned when the storage key is not // matched with the key previously given. ErrEncryptionKeyMismatch = errors.New("Encryption key mismatch") diff --git a/trie/trie.go b/trie/trie.go index f856869c5..98e4a9dcb 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -59,6 +59,13 @@ func (t *Trie) Add(prefix []byte, id uint64) { func (t *Trie) Get(key []byte) map[uint64]struct{} { out := make(map[uint64]struct{}) node := t.root + // If root has ids that means we have subscribers for "nil/[]byte{}" + // prefix. Add them to the list. + if len(node.ids) > 0 { + for _, i := range node.ids { + out[i] = struct{}{} + } + } for _, val := range key { child, ok := node.children[val] if !ok { diff --git a/trie/trie_test.go b/trie/trie_test.go index 31b4854a0..ac02374d7 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -30,16 +30,25 @@ func TestGet(t *testing.T) { trie.Add([]byte("hel"), 20) trie.Add([]byte("he"), 20) trie.Add([]byte("badger"), 30) + + trie.Add(nil, 10) + require.Equal(t, map[uint64]struct{}{10: {}}, trie.Get([]byte("A"))) + ids := trie.Get([]byte("hel")) - require.Equal(t, 1, len(ids)) + require.Equal(t, 2, len(ids)) + require.Equal(t, map[uint64]struct{}{10: {}, 20: {}}, ids) - require.Equal(t, map[uint64]struct{}{20: {}}, ids) ids = trie.Get([]byte("badger")) - require.Equal(t, 1, len(ids)) - require.Equal(t, map[uint64]struct{}{30: {}}, ids) + require.Equal(t, 2, len(ids)) + require.Equal(t, map[uint64]struct{}{10: {}, 30: {}}, ids) + ids = trie.Get([]byte("hello")) - require.Equal(t, 4, len(ids)) - require.Equal(t, map[uint64]struct{}{1: {}, 3: {}, 4: {}, 20: {}}, ids) + require.Equal(t, 5, len(ids)) + require.Equal(t, map[uint64]struct{}{10: {}, 1: {}, 3: {}, 4: {}, 20: {}}, ids) + + trie.Add([]byte{}, 11) + require.Equal(t, map[uint64]struct{}{10: {}, 11: {}}, trie.Get([]byte("A"))) + } func TestTrieDelete(t *testing.T) { @@ -47,6 +56,12 @@ func TestTrieDelete(t *testing.T) { trie.Add([]byte("hello"), 1) trie.Add([]byte("hello"), 3) trie.Add([]byte("hello"), 4) + trie.Add(nil, 5) + trie.Delete([]byte("hello"), 4) + + require.Equal(t, map[uint64]struct{}{5: {}, 1: {}, 3: {}}, trie.Get([]byte("hello"))) + + trie.Delete(nil, 5) require.Equal(t, map[uint64]struct{}{1: {}, 3: {}}, trie.Get([]byte("hello"))) }