diff --git a/lib/backend/backend.go b/lib/backend/backend.go index 96f6d3a2bde26..a782edd2c2b59 100644 --- a/lib/backend/backend.go +++ b/lib/backend/backend.go @@ -167,18 +167,21 @@ func IterateRange(ctx context.Context, bk Backend, startKey, endKey Key, limit i // 2. allow individual backends to expose custom streaming methods s.t. the most performant // impl for a given backend may be used. func StreamRange(ctx context.Context, bk Backend, startKey, endKey Key, pageSize int) stream.Stream[Item] { - return stream.PageFunc[Item](func() ([]Item, error) { - if startKey.components == nil { + var done bool + return stream.PageFunc(func() ([]Item, error) { + if done { return nil, io.EOF } - rslt, err := bk.GetRange(ctx, startKey, endKey, pageSize) + rslt, err := bk.GetRange(ctx, startKey, endKey, pageSize+1) if err != nil { return nil, trace.Wrap(err) } - if len(rslt.Items) < pageSize { - startKey = Key{} + if len(rslt.Items) > pageSize { + startKey = rslt.Items[pageSize].Key + clear(rslt.Items[pageSize:]) + rslt.Items = rslt.Items[:pageSize] } else { - startKey = nextKey(rslt.Items[pageSize-1].Key) + done = true } return rslt.Items, nil }) @@ -298,10 +301,10 @@ func (p Params) GetString(key string) string { // NoLimit specifies no limits const NoLimit = 0 -// nextKey returns the next possible key. -// If used with a key prefix, this will return -// the end of the range for that key prefix. -func nextKey(key Key) Key { +const noEnd = "\x00" + +// RangeEnd returns end of the range for given key. +func RangeEnd(key Key) Key { end := make([]byte, len(key.s)) copy(end, key.s) for i := len(end) - 1; i >= 0; i-- { @@ -315,13 +318,6 @@ func nextKey(key Key) Key { return Key{noEnd: true} } -var noEnd = []byte{0} - -// RangeEnd returns end of the range for given key. -func RangeEnd(key Key) Key { - return nextKey(key) -} - // HostID is a derivation of a KeyedItem that allows the host id // to be included in the key. type HostID interface { @@ -345,7 +341,7 @@ func NextPaginationKey(ki KeyedItem) string { key = NewKey(ki.GetName()) } - return nextKey(key).String() + return RangeEnd(key).String() } // GetPaginationKey returns the pagination key given item. diff --git a/lib/backend/key.go b/lib/backend/key.go index 881704be2a254..ea79724e98eb5 100644 --- a/lib/backend/key.go +++ b/lib/backend/key.go @@ -69,7 +69,7 @@ func KeyFromString(s string) Key { components: components, s: s, exactKey: s == SeparatorString || (s != "" && s[len(s)-1] == Separator), - noEnd: s == string(noEnd), + noEnd: s == noEnd, } } @@ -100,7 +100,7 @@ func (k Key) ExactKey() Key { // each component concatenated together via the [Separator]. func (k Key) String() string { if k.noEnd { - return string(noEnd) + return noEnd } return k.s diff --git a/lib/backend/memory/memory_test.go b/lib/backend/memory/memory_test.go index 46e8b7532fdfa..b16611af22d46 100644 --- a/lib/backend/memory/memory_test.go +++ b/lib/backend/memory/memory_test.go @@ -21,6 +21,7 @@ package memory import ( "context" "os" + "slices" "strconv" "strings" "testing" @@ -127,3 +128,31 @@ func TestIterateRange(t *testing.T) { require.NoError(t, err) require.Equal(t, 20, scount) } + +func TestStreamRange(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + m, err := New(Config{}) + require.NoError(t, err) + defer m.Close() + + const N = 10 + for i := range 10 * N { + _, err := m.Put(ctx, backend.Item{ + Key: backend.NewKey("foo", strings.Repeat("a", i+1)), + Value: []byte("\x00"), + }) + require.NoError(t, err) + } + + var items []string + st := backend.StreamRange(ctx, m, backend.ExactKey("foo"), backend.RangeEnd(backend.ExactKey("foo")), N) + for st.Next() { + items = append(items, st.Item().Key.String()) + } + require.NoError(t, st.Done()) + + require.Len(t, items, 10*N) + require.True(t, slices.IsSorted(items)) +} diff --git a/lib/backend/sanitize.go b/lib/backend/sanitize.go index b188e247ae5d2..fed2ff257120a 100644 --- a/lib/backend/sanitize.go +++ b/lib/backend/sanitize.go @@ -57,7 +57,7 @@ func IsKeySafe(key Key) bool { components := key.Components() for i, k := range components { switch k { - case string(noEnd): + case noEnd: continue case ".", "..": return false