diff --git a/lib/backend/backend.go b/lib/backend/backend.go index 7506a605fd401..05453be5d04ec 100644 --- a/lib/backend/backend.go +++ b/lib/backend/backend.go @@ -189,18 +189,21 @@ func IterateRange(ctx context.Context, bk Backend, startKey, endKey Key, limit i // 2. allow individual backends to expose custom streaming methods s.t. the most performant // impl for a given backend may be used. func StreamRange(ctx context.Context, bk Backend, startKey, endKey Key, pageSize int) stream.Stream[Item] { - return stream.PageFunc[Item](func() ([]Item, error) { - if startKey.components == nil { + var done bool + return stream.PageFunc(func() ([]Item, error) { + if done { return nil, io.EOF } - rslt, err := bk.GetRange(ctx, startKey, endKey, pageSize) + rslt, err := bk.GetRange(ctx, startKey, endKey, pageSize+1) if err != nil { return nil, trace.Wrap(err) } - if len(rslt.Items) < pageSize { - startKey = Key{} + if len(rslt.Items) > pageSize { + startKey = rslt.Items[pageSize].Key + clear(rslt.Items[pageSize:]) + rslt.Items = rslt.Items[:pageSize] } else { - startKey = nextKey(rslt.Items[pageSize-1].Key) + done = true } return rslt.Items, nil }) @@ -320,10 +323,10 @@ func (p Params) GetString(key string) string { // NoLimit specifies no limits const NoLimit = 0 -// nextKey returns the next possible key. -// If used with a key prefix, this will return -// the end of the range for that key prefix. -func nextKey(key Key) Key { +const noEnd = "\x00" + +// RangeEnd returns end of the range for given key. +func RangeEnd(key Key) Key { end := make([]byte, len(key.s)) copy(end, key.s) for i := len(end) - 1; i >= 0; i-- { @@ -337,13 +340,6 @@ func nextKey(key Key) Key { return Key{noEnd: true} } -var noEnd = []byte{0} - -// RangeEnd returns end of the range for given key. -func RangeEnd(key Key) Key { - return nextKey(key) -} - // HostID is a derivation of a KeyedItem that allows the host id // to be included in the key. type HostID interface { diff --git a/lib/backend/key.go b/lib/backend/key.go index 881704be2a254..ea79724e98eb5 100644 --- a/lib/backend/key.go +++ b/lib/backend/key.go @@ -69,7 +69,7 @@ func KeyFromString(s string) Key { components: components, s: s, exactKey: s == SeparatorString || (s != "" && s[len(s)-1] == Separator), - noEnd: s == string(noEnd), + noEnd: s == noEnd, } } @@ -100,7 +100,7 @@ func (k Key) ExactKey() Key { // each component concatenated together via the [Separator]. func (k Key) String() string { if k.noEnd { - return string(noEnd) + return noEnd } return k.s diff --git a/lib/backend/memory/memory_test.go b/lib/backend/memory/memory_test.go index 442b97343a0ee..3d50435349a3e 100644 --- a/lib/backend/memory/memory_test.go +++ b/lib/backend/memory/memory_test.go @@ -21,6 +21,7 @@ package memory import ( "context" "os" + "slices" "strconv" "strings" "testing" @@ -128,3 +129,31 @@ func TestIterateRange(t *testing.T) { require.NoError(t, err) require.Equal(t, 20, scount) } + +func TestStreamRange(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + m, err := New(Config{}) + require.NoError(t, err) + defer m.Close() + + const N = 10 + for i := range 10 * N { + _, err := m.Put(ctx, backend.Item{ + Key: backend.NewKey("foo", strings.Repeat("a", i+1)), + Value: []byte("\x00"), + }) + require.NoError(t, err) + } + + var items []string + st := backend.StreamRange(ctx, m, backend.ExactKey("foo"), backend.RangeEnd(backend.ExactKey("foo")), N) + for st.Next() { + items = append(items, st.Item().Key.String()) + } + require.NoError(t, st.Done()) + + require.Len(t, items, 10*N) + require.True(t, slices.IsSorted(items)) +} diff --git a/lib/backend/sanitize.go b/lib/backend/sanitize.go index f5c31d1ea6c09..e84a5006e4340 100644 --- a/lib/backend/sanitize.go +++ b/lib/backend/sanitize.go @@ -58,7 +58,7 @@ func IsKeySafe(key Key) bool { components := key.Components() for i, k := range components { switch k { - case string(noEnd): + case noEnd: continue case ".", "..": return false